From 88d583bfb1de55e7dce580759329d2f8dea1810c Mon Sep 17 00:00:00 2001 From: Arvindh <30824765+arvindh123@users.noreply.github.com> Date: Mon, 20 Jan 2025 17:06:50 +0530 Subject: [PATCH] SMQ-2605 - Groups replication with groups events consumer & listing of things and channels (#2639) Signed-off-by: Arvindh --- api/http/common.go | 1 + channels/api/http/decode.go | 108 ++-- channels/api/http/endpoints.go | 24 +- channels/api/http/requests.go | 30 +- channels/api/http/requests_test.go | 8 - channels/channels.go | 68 ++- channels/events/events.go | 57 +-- channels/events/streams.go | 8 +- channels/middleware/authorization.go | 10 +- channels/middleware/logging.go | 10 +- channels/middleware/metrics.go | 8 +- channels/mocks/repository.go | 28 ++ channels/mocks/service.go | 14 +- channels/postgres/channels.go | 454 +++++++++++++++-- channels/postgres/init.go | 9 + channels/roleoperations.go | 2 +- channels/service.go | 78 +-- channels/service_test.go | 307 +++++------- channels/tracing/tracing.go | 6 +- cli/users.go | 26 - clients/api/http/decode.go | 112 +++-- clients/api/http/endpoints.go | 38 +- clients/api/http/endpoints_test.go | 2 +- clients/api/http/requests.go | 26 +- clients/api/http/requests_test.go | 8 - clients/clients.go | 66 ++- clients/events/events.go | 52 +- clients/events/streams.go | 26 +- clients/middleware/authorization.go | 28 +- clients/middleware/logging.go | 26 +- clients/middleware/metrics.go | 12 +- clients/mocks/repository.go | 56 +-- clients/mocks/service.go | 40 +- clients/postgres/clients.go | 472 +++++++++++++++--- clients/postgres/clients_test.go | 269 +--------- clients/postgres/init.go | 8 + clients/roleoperations.go | 2 +- clients/service.go | 114 +---- clients/service_test.go | 123 +---- clients/status_test.go | 2 +- clients/tracing/tracing.go | 10 +- cmd/channels/main.go | 12 + cmd/clients/main.go | 12 + groups/events/events.go | 14 +- groups/groups.go | 3 +- groups/middleware/authorization.go | 1 + pkg/errors/repository/types.go | 3 +- pkg/events/events.go | 1 + pkg/events/nats/subscriber.go | 6 +- pkg/groups/events/consumer/decode.go | 255 ++++++++++ pkg/groups/events/consumer/doc.go | 6 + pkg/groups/events/consumer/streams.go | 253 ++++++++++ pkg/groups/events/doc.go | 6 + pkg/messaging/nats/pubsub.go | 15 +- pkg/messaging/pubsub.go | 2 + pkg/roles/repo/postgres/init.go | 18 +- pkg/roles/repo/postgres/roles.go | 2 +- .../rolemanager/events/consumer/decode.go | 146 ++++++ .../rolemanager/events/consumer/handler.go | 188 +++++++ pkg/roles/rolemanager/events/streams.go | 7 +- pkg/sdk/channels_test.go | 111 ++-- pkg/sdk/clients.go | 17 - pkg/sdk/clients_test.go | 319 ++---------- pkg/sdk/mocks/sdk.go | 61 --- pkg/sdk/sdk.go | 11 - pkg/sdk/setup_test.go | 1 - 66 files changed, 2643 insertions(+), 1575 deletions(-) create mode 100644 pkg/groups/events/consumer/decode.go create mode 100644 pkg/groups/events/consumer/doc.go create mode 100644 pkg/groups/events/consumer/streams.go create mode 100644 pkg/groups/events/doc.go create mode 100644 pkg/roles/rolemanager/events/consumer/decode.go create mode 100644 pkg/roles/rolemanager/events/consumer/handler.go diff --git a/api/http/common.go b/api/http/common.go index b326abecc5..08b8e64ef1 100644 --- a/api/http/common.go +++ b/api/http/common.go @@ -58,6 +58,7 @@ const ( UserKey = "user" DomainKey = "domain" ChannelKey = "channel" + ConnTypeKey = "connection_type" DefPermission = "read_permission" DefTotal = uint64(100) DefOffset = 0 diff --git a/channels/api/http/decode.go b/channels/api/http/decode.go index ea662f8da0..e2ce4ceac1 100644 --- a/channels/api/http/decode.go +++ b/channels/api/http/decode.go @@ -11,7 +11,7 @@ import ( api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" - smqclients "github.com/absmach/supermq/clients" + "github.com/absmach/supermq/clients" "github.com/absmach/supermq/pkg/errors" "github.com/go-chi/chi/v5" ) @@ -51,58 +51,106 @@ func decodeCreateChannelsReq(_ context.Context, r *http.Request) (interface{}, e } func decodeListChannels(_ context.Context, r *http.Request) (interface{}, error) { - s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefClientStatus) + name, err := apiutil.ReadStringQuery(r, api.NameKey, "") if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + + tag, err := apiutil.ReadStringQuery(r, api.TagKey, "") if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + + s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - m, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil) + status, err := clients.ToStatus(s) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - n, err := apiutil.ReadStringQuery(r, api.NameKey, "") + + meta, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - t, err := apiutil.ReadStringQuery(r, api.TagKey, "") + + offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - id, err := apiutil.ReadStringQuery(r, api.IDOrder, "") + limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - p, err := apiutil.ReadStringQuery(r, api.PermissionKey, api.DefPermission) + + dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - lp, err := apiutil.ReadBoolQuery(r, api.ListPerms, api.DefListPerms) + order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - st, err := smqclients.ToStatus(s) + + allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "") if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + actions := []string{} + + allActions = strings.TrimSpace(allActions) + if allActions != "" { + actions = strings.Split(allActions, ",") } + roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "") + if err != nil { + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "") + if err != nil { + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + accessType, err := apiutil.ReadStringQuery(r, api.AccessTypeKey, "") + if err != nil { + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + userID, err := apiutil.ReadStringQuery(r, api.UserKey, "") + if err != nil { + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + groupID, err := apiutil.ReadStringQuery(r, api.GroupKey, "") + if err != nil { + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + clientID, err := apiutil.ReadStringQuery(r, api.ClientKey, "") + if err != nil { + return listChannelsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + req := listChannelsReq{ - status: st, - offset: o, - limit: l, - metadata: m, - name: n, - tag: t, - permission: p, - listPerms: lp, - userID: chi.URLParam(r, "userID"), - id: id, + name: name, + tag: tag, + status: status, + metadata: meta, + roleName: roleName, + roleID: roleID, + actions: actions, + accessType: accessType, + order: order, + dir: dir, + offset: offset, + limit: limit, + groupID: groupID, + clientID: clientID, + userID: userID, } return req, nil } diff --git a/channels/api/http/endpoints.go b/channels/api/http/endpoints.go index 62f5500499..503f32911f 100644 --- a/channels/api/http/endpoints.go +++ b/channels/api/http/endpoints.go @@ -104,15 +104,21 @@ func listChannelsEndpoint(svc channels.Service) endpoint.Endpoint { } pm := channels.PageMetadata{ - Status: req.status, - Offset: req.offset, - Limit: req.limit, - Name: req.name, - Tag: req.tag, - Permission: req.permission, - Metadata: req.metadata, - ListPerms: req.listPerms, - Id: req.id, + Offset: req.offset, + Limit: req.limit, + Name: req.name, + Order: req.order, + Dir: req.dir, + Metadata: req.metadata, + Tag: req.tag, + Status: req.status, + Group: req.groupID, + Client: req.clientID, + ConnectionType: req.connType, + RoleName: req.roleName, + RoleID: req.roleID, + Actions: req.actions, + AccessType: req.accessType, } page, err := svc.ListChannels(ctx, session, pm) if err != nil { diff --git a/channels/api/http/requests.go b/channels/api/http/requests.go index cd5b64e2ce..2d27d16fb9 100644 --- a/channels/api/http/requests.go +++ b/channels/api/http/requests.go @@ -9,7 +9,7 @@ import ( api "github.com/absmach/supermq/api/http" apiutil "github.com/absmach/supermq/api/http/util" "github.com/absmach/supermq/channels" - smqclients "github.com/absmach/supermq/clients" + "github.com/absmach/supermq/clients" "github.com/absmach/supermq/pkg/connections" ) @@ -64,29 +64,29 @@ func (req viewChannelReq) validate() error { } type listChannelsReq struct { - status smqclients.Status - offset uint64 - limit uint64 name string tag string - permission string - visibility string + status clients.Status + metadata clients.Metadata + roleName string + roleID string + actions []string + accessType string + order string + dir string + offset uint64 + limit uint64 + groupID string + clientID string + connType string userID string - listPerms bool - metadata smqclients.Metadata - id string } func (req listChannelsReq) validate() error { if req.limit > api.MaxLimitSize || req.limit < 1 { return apiutil.ErrLimitSize } - if req.visibility != "" && - req.visibility != api.AllVisibility && - req.visibility != api.MyVisibility && - req.visibility != api.SharedVisibility { - return apiutil.ErrInvalidVisibilityType - } + if len(req.name) > api.MaxNameSize { return apiutil.ErrNameSize } diff --git a/channels/api/http/requests_test.go b/channels/api/http/requests_test.go index 42a91d0d03..a049354adc 100644 --- a/channels/api/http/requests_test.go +++ b/channels/api/http/requests_test.go @@ -174,14 +174,6 @@ func TestListChannelsReqValidation(t *testing.T) { }, err: apiutil.ErrNameSize, }, - { - desc: "invalid visibility", - req: listChannelsReq{ - limit: 10, - visibility: "invalid", - }, - err: apiutil.ErrInvalidVisibilityType, - }, } for _, tc := range cases { err := tc.req.validate() diff --git a/channels/channels.go b/channels/channels.go index 1bdebad9de..ecf019edb7 100644 --- a/channels/channels.go +++ b/channels/channels.go @@ -22,29 +22,43 @@ type Channel struct { ParentGroup string `json:"parent_group_id,omitempty"` Domain string `json:"domain_id,omitempty"` Metadata clients.Metadata `json:"metadata,omitempty"` + CreatedBy string `json:"created_by,omitempty"` CreatedAt time.Time `json:"created_at,omitempty"` UpdatedAt time.Time `json:"updated_at,omitempty"` UpdatedBy string `json:"updated_by,omitempty"` - Status clients.Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled - Permissions []string `json:"permissions,omitempty"` // 1 for enabled, 0 for disabled + Status clients.Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled + // Extended + ParentGroupPath string `json:"parent_group_path"` + RoleID string `json:"role_id"` + RoleName string `json:"role_name"` + Actions []string `json:"actions"` + AccessType string `json:"access_type"` + AccessProviderId string `json:"access_provider_id"` + AccessProviderRoleId string `json:"access_provider_role_id"` + AccessProviderRoleName string `json:"access_provider_role_name"` + AccessProviderRoleActions []string `json:"access_provider_role_actions"` } type PageMetadata struct { - Total uint64 `json:"total"` - Offset uint64 `json:"offset"` - Limit uint64 `json:"limit"` - Name string `json:"name,omitempty"` - Id string `json:"id,omitempty"` - Order string `json:"order,omitempty"` - Dir string `json:"dir,omitempty"` - Metadata clients.Metadata `json:"metadata,omitempty"` - Domain string `json:"domain,omitempty"` - Tag string `json:"tag,omitempty"` - Permission string `json:"permission,omitempty"` - Status clients.Status `json:"status,omitempty"` - IDs []string `json:"ids,omitempty"` - ListPerms bool `json:"-"` - ClientID string `json:"-"` + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Order string `json:"order,omitempty"` + Dir string `json:"dir,omitempty"` + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Metadata clients.Metadata `json:"metadata,omitempty"` + Domain string `json:"domain,omitempty"` + Tag string `json:"tag,omitempty"` + Status clients.Status `json:"status,omitempty"` + Group string `json:"group,omitempty"` + Client string `json:"client,omitempty"` + ConnectionType string `json:"connection_type,omitempty"` + RoleName string `json:"role_name,omitempty"` + RoleID string `json:"role_id,omitempty"` + Actions []string `json:"actions,omitempty"` + AccessType string `json:"access_type,omitempty"` + IDs []string `json:"-"` } // ChannelsPage contains page related metadata as well as list of channels that @@ -71,15 +85,15 @@ type AuthzReq struct { //go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" type Service interface { - // CreateChannels adds channels to the user identified by the provided key. + // CreateChannels adds channels to the user. CreateChannels(ctx context.Context, session authn.Session, channels ...Channel) ([]Channel, []roles.RoleProvision, error) // ViewChannel retrieves data about the channel identified by the provided - // ID, that belongs to the user identified by the provided key. + // ID, that belongs to the user. ViewChannel(ctx context.Context, session authn.Session, id string) (Channel, error) // UpdateChannel updates the channel identified by the provided ID, that - // belongs to the user identified by the provided key. + // belongs to the user. UpdateChannel(ctx context.Context, session authn.Session, channel Channel) (Channel, error) // UpdateChannelTags updates the channel's tags. @@ -89,17 +103,14 @@ type Service interface { DisableChannel(ctx context.Context, session authn.Session, id string) (Channel, error) - // ListChannels retrieves data about subset of channels that belongs to the - // user identified by the provided key. + // ListChannels retrieves data about subset of channels that belongs to the user. ListChannels(ctx context.Context, session authn.Session, pm PageMetadata) (Page, error) - // ListChannelsByClient retrieves data about subset of channels that have - // specified client connected or not connected to them and belong to the user identified by - // the provided key. - ListChannelsByClient(ctx context.Context, session authn.Session, id string, pm PageMetadata) (Page, error) + // ListUserChannels retrieves data about subset of channels that belong to the specified user. + ListUserChannels(ctx context.Context, session authn.Session, userID string, pm PageMetadata) (Page, error) // RemoveChannel removes the client identified by the provided ID, that - // belongs to the user identified by the provided key. + // belongs to the user. RemoveChannel(ctx context.Context, session authn.Session, id string) error // Connect adds clients to the channels list of connected clients. @@ -131,6 +142,9 @@ type Repository interface { ChangeStatus(ctx context.Context, channel Channel) (Channel, error) + // RetrieveUserChannels retrieves the channel of given domainID and userID. + RetrieveUserChannels(ctx context.Context, domainID, userID string, pm PageMetadata) (Page, error) + // RetrieveByID retrieves the channel having the provided identifier RetrieveByID(ctx context.Context, id string) (Channel, error) diff --git a/channels/events/events.go b/channels/events/events.go index 237bf84119..bcfa3545e3 100644 --- a/channels/events/events.go +++ b/channels/events/events.go @@ -206,9 +206,6 @@ func (lce listChannelEvent) Encode() (map[string]interface{}, error) { if lce.Tag != "" { val["tag"] = lce.Tag } - if lce.Permission != "" { - val["permission"] = lce.Permission - } if lce.Status.String() != "" { val["status"] = lce.Status.String() } @@ -219,48 +216,48 @@ func (lce listChannelEvent) Encode() (map[string]interface{}, error) { return val, nil } -type listChannelByClientEvent struct { - clientID string +type listUserChannelsEvent struct { + userID string channels.PageMetadata authn.Session } -func (lcte listChannelByClientEvent) Encode() (map[string]interface{}, error) { +func (luce listUserChannelsEvent) Encode() (map[string]interface{}, error) { val := map[string]interface{}{ "operation": channelList, - "client_id": lcte.clientID, - "total": lcte.Total, - "offset": lcte.Offset, - "limit": lcte.Limit, - "domain": lcte.DomainID, - "user_id": lcte.UserID, - "token_type": lcte.Type.String(), - "super_admin": lcte.SuperAdmin, + "req_user_id": luce.userID, + "total": luce.Total, + "offset": luce.Offset, + "limit": luce.Limit, + "domain": luce.DomainID, + "user_id": luce.UserID, + "token_type": luce.Type.String(), + "super_admin": luce.SuperAdmin, } - if lcte.Name != "" { - val["name"] = lcte.Name + if luce.Name != "" { + val["name"] = luce.Name } - if lcte.Order != "" { - val["order"] = lcte.Order + if luce.Order != "" { + val["order"] = luce.Order } - if lcte.Dir != "" { - val["dir"] = lcte.Dir + if luce.Dir != "" { + val["dir"] = luce.Dir } - if lcte.Metadata != nil { - val["metadata"] = lcte.Metadata + if luce.Metadata != nil { + val["metadata"] = luce.Metadata } - if lcte.Tag != "" { - val["tag"] = lcte.Tag + if luce.Domain != "" { + val["domain"] = luce.Domain } - if lcte.Permission != "" { - val["permission"] = lcte.Permission + if luce.Tag != "" { + val["tag"] = luce.Tag } - if lcte.Status.String() != "" { - val["status"] = lcte.Status.String() + if luce.Status.String() != "" { + val["status"] = luce.Status.String() } - if len(lcte.IDs) > 0 { - val["ids"] = lcte.IDs + if len(luce.IDs) > 0 { + val["ids"] = luce.IDs } return val, nil diff --git a/channels/events/streams.go b/channels/events/streams.go index 4422d27997..7a7c4a9e2b 100644 --- a/channels/events/streams.go +++ b/channels/events/streams.go @@ -126,13 +126,13 @@ func (es *eventStore) ListChannels(ctx context.Context, session authn.Session, p return cp, nil } -func (es *eventStore) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) { - cp, err := es.svc.ListChannelsByClient(ctx, session, clientID, pm) +func (es *eventStore) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) { + cp, err := es.svc.ListUserChannels(ctx, session, userID, pm) if err != nil { return cp, err } - event := listChannelByClientEvent{ - clientID: clientID, + event := listUserChannelsEvent{ + userID: userID, PageMetadata: pm, Session: session, } diff --git a/channels/middleware/authorization.go b/channels/middleware/authorization.go index 43d50f417f..ebf0ff19ec 100644 --- a/channels/middleware/authorization.go +++ b/channels/middleware/authorization.go @@ -22,6 +22,7 @@ import ( var ( errView = errors.New("not authorized to view channel") + errList = errors.New("not authorized to list user channels") errUpdate = errors.New("not authorized to update channel") errUpdateTags = errors.New("not authorized to update channel tags") errEnable = errors.New("not authorized to enable channel") @@ -164,13 +165,13 @@ func (am *authorizationMiddleware) ListChannels(ctx context.Context, session aut } } - if err := am.checkSuperAdmin(ctx, session.UserID); err != nil { + if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } return am.svc.ListChannels(ctx, session, pm) } -func (am *authorizationMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) { +func (am *authorizationMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ UserID: session.UserID, @@ -184,7 +185,10 @@ func (am *authorizationMiddleware) ListChannelsByClient(ctx context.Context, ses return channels.Page{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) } } - return am.svc.ListChannelsByClient(ctx, session, clientID, pm) + if err := am.checkSuperAdmin(ctx, session.UserID); err != nil { + return channels.Page{}, errors.Wrap(err, errList) + } + return am.svc.ListUserChannels(ctx, session, userID, pm) } func (am *authorizationMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) { diff --git a/channels/middleware/logging.go b/channels/middleware/logging.go index 5decb03731..5dc97c8a6a 100644 --- a/channels/middleware/logging.go +++ b/channels/middleware/logging.go @@ -82,11 +82,11 @@ func (lm *loggingMiddleware) ListChannels(ctx context.Context, session authn.Ses return lm.svc.ListChannels(ctx, session, pm) } -func (lm *loggingMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (cp channels.Page, err error) { +func (lm *loggingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (cp channels.Page, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), - slog.String("client_id", clientID), + slog.String("user_id", userID), slog.Group("page", slog.Uint64("limit", pm.Limit), slog.Uint64("offset", pm.Offset), @@ -95,12 +95,12 @@ func (lm *loggingMiddleware) ListChannelsByClient(ctx context.Context, session a } if err != nil { args = append(args, slog.String("error", err.Error())) - lm.logger.Warn("List channels by client failed", args...) + lm.logger.Warn("List user channels failed", args...) return } - lm.logger.Info("List channels by client completed successfully", args...) + lm.logger.Info("List user channels completed successfully", args...) }(time.Now()) - return lm.svc.ListChannelsByClient(ctx, session, clientID, pm) + return lm.svc.ListUserChannels(ctx, session, userID, pm) } func (lm *loggingMiddleware) UpdateChannel(ctx context.Context, session authn.Session, client channels.Channel) (c channels.Channel, err error) { diff --git a/channels/middleware/metrics.go b/channels/middleware/metrics.go index e41447bdaf..6c12bbc9c5 100644 --- a/channels/middleware/metrics.go +++ b/channels/middleware/metrics.go @@ -58,12 +58,12 @@ func (ms *metricsMiddleware) ListChannels(ctx context.Context, session authn.Ses return ms.svc.ListChannels(ctx, session, pm) } -func (ms *metricsMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) { +func (ms *metricsMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) { defer func(begin time.Time) { - ms.counter.With("method", "list_channels_by_client").Add(1) - ms.latency.With("method", "list_channels_by_client").Observe(time.Since(begin).Seconds()) + ms.counter.With("method", "list_user_channels").Add(1) + ms.latency.With("method", "list_user_channels").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.ListChannelsByClient(ctx, session, clientID, pm) + return ms.svc.ListUserChannels(ctx, session, userID, pm) } func (ms *metricsMiddleware) UpdateChannel(ctx context.Context, session authn.Session, channel channels.Channel) (channels.Channel, error) { diff --git a/channels/mocks/repository.go b/channels/mocks/repository.go index 69b7cd32bd..b22ab4dcce 100644 --- a/channels/mocks/repository.go +++ b/channels/mocks/repository.go @@ -529,6 +529,34 @@ func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Ro return r0, r1 } +// RetrieveUserChannels provides a mock function with given fields: ctx, domainID, userID, pm +func (_m *Repository) RetrieveUserChannels(ctx context.Context, domainID string, userID string, pm channels.PageMetadata) (channels.Page, error) { + ret := _m.Called(ctx, domainID, userID, pm) + + if len(ret) == 0 { + panic("no return value specified for RetrieveUserChannels") + } + + var r0 channels.Page + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.PageMetadata) (channels.Page, error)); ok { + return rf(ctx, domainID, userID, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, channels.PageMetadata) channels.Page); ok { + r0 = rf(ctx, domainID, userID, pm) + } else { + r0 = ret.Get(0).(channels.Page) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, channels.PageMetadata) error); ok { + r1 = rf(ctx, domainID, userID, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RoleAddActions provides a mock function with given fields: ctx, role, actions func (_m *Repository) RoleAddActions(ctx context.Context, role roles.Role, actions []string) ([]string, error) { ret := _m.Called(ctx, role, actions) diff --git a/channels/mocks/service.go b/channels/mocks/service.go index 52fadd25fd..0aa25aca3f 100644 --- a/channels/mocks/service.go +++ b/channels/mocks/service.go @@ -246,27 +246,27 @@ func (_m *Service) ListChannels(ctx context.Context, session authn.Session, pm c return r0, r1 } -// ListChannelsByClient provides a mock function with given fields: ctx, session, id, pm -func (_m *Service) ListChannelsByClient(ctx context.Context, session authn.Session, id string, pm channels.PageMetadata) (channels.Page, error) { - ret := _m.Called(ctx, session, id, pm) +// ListUserChannels provides a mock function with given fields: ctx, session, userID, pm +func (_m *Service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) { + ret := _m.Called(ctx, session, userID, pm) if len(ret) == 0 { - panic("no return value specified for ListChannelsByClient") + panic("no return value specified for ListUserChannels") } var r0 channels.Page var r1 error if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.PageMetadata) (channels.Page, error)); ok { - return rf(ctx, session, id, pm) + return rf(ctx, session, userID, pm) } if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, channels.PageMetadata) channels.Page); ok { - r0 = rf(ctx, session, id, pm) + r0 = rf(ctx, session, userID, pm) } else { r0 = ret.Get(0).(channels.Page) } if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, channels.PageMetadata) error); ok { - r1 = rf(ctx, session, id, pm) + r1 = rf(ctx, session, userID, pm) } else { r1 = ret.Error(1) } diff --git a/channels/postgres/channels.go b/channels/postgres/channels.go index 0905ea00fd..1264561dbb 100644 --- a/channels/postgres/channels.go +++ b/channels/postgres/channels.go @@ -21,6 +21,7 @@ import ( "github.com/absmach/supermq/pkg/postgres" rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres" "github.com/jackc/pgtype" + "github.com/lib/pq" ) const ( @@ -152,7 +153,7 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe query = applyOrdering(query, pm) q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status, - c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query) + c.created_by, c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c %s LIMIT :limit OFFSET :offset;`, query) dbPage, err := toDBChannelsPage(pm) if err != nil { @@ -196,6 +197,303 @@ func (cr *channelRepository) RetrieveAll(ctx context.Context, pm channels.PageMe return page, nil } +func (repo *channelRepository) RetrieveUserChannels(ctx context.Context, domainID, userID string, pm channels.PageMetadata) (channels.Page, error) { + return repo.retrieveClients(ctx, domainID, userID, pm) +} + +func (repo *channelRepository) retrieveClients(ctx context.Context, domainID, userID string, pm channels.PageMetadata) (channels.Page, error) { + pageQuery, err := PageQuery(pm) + if err != nil { + return channels.Page{}, err + } + + bq := repo.userChannelsBaseQuery(domainID, userID) + + connJoinQuery := "" + if pm.Client != "" { + connJoinQuery = "JOIN connection conn ON conn.channel_id = c.id" + } + + q := fmt.Sprintf(` + %s + SELECT + c.id, + c.name, + c.domain_id, + c.parent_group_id, + c.tags, + c.metadata, + c.created_by, + c.created_at, + c.updated_at, + c.updated_by, + c.status, + c.parent_group_path, + c.role_id, + c.role_name, + c.actions, + c.access_type, + c.access_provider_id, + c.access_provider_role_id, + c.access_provider_role_name, + c.access_provider_role_actions + FROM + final_channels c + %s + %s + `, bq, connJoinQuery, pageQuery) + + q = applyOrdering(q, pm) + + dbPage, err := toDBChannelsPage(pm) + if err != nil { + return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + rows, err := repo.db.NamedQueryContext(ctx, q, dbPage) + if err != nil { + return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + defer rows.Close() + + var items []channels.Channel + for rows.Next() { + dbc := dbChannel{} + if err := rows.StructScan(&dbc); err != nil { + return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + c, err := toChannel(dbc) + if err != nil { + return channels.Page{}, err + } + + items = append(items, c) + } + + chJoinQuery := "" + if pm.Client != "" { + chJoinQuery = "JOIN connection conn ON conn.channel_id = c.id" + } + cq := fmt.Sprintf(`%s + SELECT COUNT(*) AS total_count + FROM ( + SELECT + c.id, + c.name, + c.domain_id, + c.parent_group_id, + c.tags, + c.metadata, + c.created_by, + c.created_at, + c.updated_at, + c.updated_by, + c.status, + c.parent_group_path, + c.role_id, + c.role_name, + c.actions, + c.access_type, + c.access_provider_id, + c.access_provider_role_id, + c.access_provider_role_name, + c.access_provider_role_actions + FROM + final_channels c + %s + %s + ) AS subquery; + `, bq, chJoinQuery, pageQuery) + + total, err := postgres.Total(ctx, repo.db, cq, dbPage) + if err != nil { + return channels.Page{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + page := channels.Page{ + Channels: items, + PageMetadata: channels.PageMetadata{ + Total: total, + Offset: pm.Offset, + Limit: pm.Limit, + }, + } + + return page, nil +} + +func (repo *channelRepository) userChannelsBaseQuery(domainID, userID string) string { + return fmt.Sprintf(` + WITH direct_channels AS ( + select + c.id, + c.name, + c.domain_id, + c.parent_group_id, + c.tags, + c.metadata, + c.created_by, + c.created_at, + c.updated_at, + c.updated_by, + c.status, + text2ltree('') as parent_group_path, + cr.id AS role_id, + cr."name" AS role_name, + array_agg(cra."action") AS actions, + 'direct' as access_type, + '' AS access_provider_id, + '' AS access_provider_role_id, + '' AS access_provider_role_name, + array[]::::text[] AS access_provider_role_actions + FROM + channels_role_members crm + JOIN + channels_role_actions cra ON cra.role_id = crm.role_id + JOIN + channels_roles cr ON cr.id = crm.role_id + JOIN + channels c ON c.id = cr.entity_id + WHERE + crm.member_id = '%s' + AND c.domain_id = '%s' + GROUP BY + cr.entity_id, crm.member_id, cr.id, cr."name", c.id + ), + direct_groups AS ( + SELECT + g.*, + gr.entity_id AS entity_id, + grm.member_id AS member_id, + gr.id AS role_id, + gr."name" AS role_name, + array_agg(gra."action") AS actions + FROM + groups_role_members grm + JOIN + groups_role_actions gra ON gra.role_id = grm.role_id + JOIN + groups_roles gr ON gr.id = grm.role_id + JOIN + "groups" g ON g.id = gr.entity_id + WHERE + grm.member_id = '%s' + AND g.domain_id = '%s' + GROUP BY + gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id + ), + direct_groups_with_subgroup AS ( + SELECT + * + FROM direct_groups + WHERE EXISTS ( + SELECT 1 + FROM unnest(direct_groups.actions) AS action + WHERE action LIKE 'subgroup_%%' + ) + ), + indirect_child_groups AS ( + SELECT + DISTINCT indirect_child_groups.id as child_id, + indirect_child_groups.*, + dgws.id as access_provider_id, + dgws.role_id as access_provider_role_id, + dgws.role_name as access_provider_role_name, + dgws.actions as access_provider_role_actions + FROM + direct_groups_with_subgroup dgws + JOIN + groups indirect_child_groups ON indirect_child_groups.path <@ dgws.path + WHERE + indirect_child_groups.domain_id = '%s' + AND NOT EXISTS ( + SELECT 1 + FROM direct_groups_with_subgroup dgws + WHERE dgws.id = indirect_child_groups.id + ) + ), + final_groups AS ( + SELECT + id, + parent_id, + domain_id, + "name", + description, + metadata, + created_at, + updated_at, + updated_by, + status, + "path", + role_id, + role_name, + actions, + 'direct_group' AS access_type, + '' AS access_provider_id, + '' AS access_provider_role_id, + '' AS access_provider_role_name, + array[]::::text[] AS access_provider_role_actions + FROM + direct_groups + UNION + SELECT + id, + parent_id, + domain_id, + "name", + description, + metadata, + created_at, + updated_at, + updated_by, + status, + "path", + '' AS role_id, + '' AS role_name, + array[]::::text[] AS actions, + 'indirect_group' AS access_type, + access_provider_id, + access_provider_role_id, + access_provider_role_name, + access_provider_role_actions + FROM + indirect_child_groups + ), + final_channels AS ( + SELECT + c.id, + c.name, + c.domain_id, + c.parent_group_id, + c.tags, + c.metadata, + c.created_by, + c.created_at, + c.updated_at, + c.updated_by, + c.status, + g.path AS parent_group_path, + g.role_id, + g.role_name, + g.actions, + g.access_type, + g.access_provider_id, + g.access_provider_role_id, + g.access_provider_role_name, + g.access_provider_role_actions + FROM + final_groups g + JOIN + channels c ON c.parent_group_id = g.id + WHERE + c.id NOT IN (SELECT id FROM direct_channels) + UNION + SELECT * FROM direct_channels + ) + `, userID, domainID, userID, domainID, domainID) +} + func (cr *channelRepository) Remove(ctx context.Context, ids ...string) error { q := "DELETE FROM channels AS c WHERE c.id = ANY(:channel_ids) ;" params := map[string]interface{}{ @@ -361,7 +659,7 @@ func (cr *channelRepository) RemoveChannelConnections(ctx context.Context, chann func (cr *channelRepository) RetrieveParentGroupChannels(ctx context.Context, parentGroupID string) ([]channels.Channel, error) { query := `SELECT c.id, c.name, c.tags, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status, - c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c WHERE c.parent_group_id = :parent_group_id ;` + c.created_by, c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM channels c WHERE c.parent_group_id = :parent_group_id ;` rows, err := cr.db.NamedQueryContext(ctx, query, dbChannel{ParentGroup: toNullString(parentGroupID)}) if err != nil { @@ -420,17 +718,26 @@ func (cr *channelRepository) update(ctx context.Context, ch channels.Channel, qu } type dbChannel struct { - ID string `db:"id"` - Name string `db:"name,omitempty"` - ParentGroup sql.NullString `db:"parent_group_id,omitempty"` - Tags pgtype.TextArray `db:"tags,omitempty"` - Domain string `db:"domain_id"` - Metadata []byte `db:"metadata,omitempty"` - CreatedAt time.Time `db:"created_at,omitempty"` - UpdatedAt sql.NullTime `db:"updated_at,omitempty"` - UpdatedBy *string `db:"updated_by,omitempty"` - Status clients.Status `db:"status,omitempty"` - Role *clients.Role `db:"role,omitempty"` + ID string `db:"id"` + Name string `db:"name,omitempty"` + ParentGroup sql.NullString `db:"parent_group_id,omitempty"` + Tags pgtype.TextArray `db:"tags,omitempty"` + Domain string `db:"domain_id"` + Metadata []byte `db:"metadata,omitempty"` + CreatedBy *string `db:"created_by,omitempty"` + CreatedAt time.Time `db:"created_at,omitempty"` + UpdatedAt sql.NullTime `db:"updated_at,omitempty"` + UpdatedBy *string `db:"updated_by,omitempty"` + Status clients.Status `db:"status,omitempty"` + ParentGroupPath string `db:"parent_group_path,omitempty"` + RoleID string `db:"role_id,omitempty"` + RoleName string `db:"role_name,omitempty"` + Actions pq.StringArray `db:"actions,omitempty"` + AccessType string `db:"access_type,omitempty"` + AccessProviderId string `db:"access_provider_id,omitempty"` + AccessProviderRoleId string `db:"access_provider_role_id,omitempty"` + AccessProviderRoleName string `db:"access_provider_role_name,omitempty"` + AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions,omitempty"` } func toDBChannel(ch channels.Channel) (dbChannel, error) { @@ -446,6 +753,10 @@ func toDBChannel(ch channels.Channel) (dbChannel, error) { if err := tags.Set(ch.Tags); err != nil { return dbChannel{}, err } + var createdBy *string + if ch.CreatedBy != "" { + createdBy = &ch.CreatedBy + } var updatedBy *string if ch.UpdatedBy != "" { updatedBy = &ch.UpdatedBy @@ -461,6 +772,7 @@ func toDBChannel(ch channels.Channel) (dbChannel, error) { Domain: ch.Domain, Tags: tags, Metadata: data, + CreatedBy: createdBy, CreatedAt: ch.CreatedAt, UpdatedAt: updatedAt, UpdatedBy: updatedBy, @@ -497,6 +809,10 @@ func toChannel(ch dbChannel) (channels.Channel, error) { for _, e := range ch.Tags.Elements { tags = append(tags, e.String) } + var createdBy string + if ch.CreatedBy != nil { + createdBy = *ch.CreatedBy + } var updatedBy string if ch.UpdatedBy != nil { updatedBy = *ch.UpdatedBy @@ -507,16 +823,26 @@ func toChannel(ch dbChannel) (channels.Channel, error) { } newCh := channels.Channel{ - ID: ch.ID, - Name: ch.Name, - Tags: tags, - Domain: ch.Domain, - ParentGroup: toString(ch.ParentGroup), - Metadata: metadata, - CreatedAt: ch.CreatedAt, - UpdatedAt: updatedAt, - UpdatedBy: updatedBy, - Status: ch.Status, + ID: ch.ID, + Name: ch.Name, + Tags: tags, + Domain: ch.Domain, + ParentGroup: toString(ch.ParentGroup), + Metadata: metadata, + CreatedBy: createdBy, + CreatedAt: ch.CreatedAt, + UpdatedAt: updatedAt, + UpdatedBy: updatedBy, + Status: ch.Status, + ParentGroupPath: ch.ParentGroupPath, + RoleID: ch.RoleID, + RoleName: ch.RoleName, + Actions: ch.Actions, + AccessType: ch.AccessType, + AccessProviderId: ch.AccessProviderId, + AccessProviderRoleId: ch.AccessProviderRoleId, + AccessProviderRoleName: ch.AccessProviderRoleName, + AccessProviderRoleActions: ch.AccessProviderRoleActions, } return newCh, nil @@ -533,9 +859,6 @@ func PageQuery(pm channels.PageMetadata) (string, error) { query = append(query, "c.name ILIKE '%' || :name || '%'") } - if pm.ClientID != "" { - query = append(query, "conn.client_id = :client_id") - } if pm.Id != "" { query = append(query, "c.id ILIKE '%' || :id || '%'") } @@ -543,12 +866,6 @@ func PageQuery(pm channels.PageMetadata) (string, error) { query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')") } - // If there are search params presents, use search and ignore other options. - // Always combine role with search params, so len(query) > 1. - if len(query) > 1 { - return fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")), nil - } - if mq != "" { query = append(query, mq) } @@ -562,6 +879,31 @@ func PageQuery(pm channels.PageMetadata) (string, error) { if pm.Domain != "" { query = append(query, "c.domain_id = :domain_id") } + if pm.Group != "" { + query = append(query, "c.parent_group_path @> (SELECT path from groups where id = :group_id) ") + } + if pm.Client != "" { + query = append(query, "conn.client_id = :client_id ") + if pm.ConnectionType != "" { + query = append(query, "conn.type = :conn_type ") + } + } + if pm.AccessType != "" { + query = append(query, "c.access_type = :access_type") + } + if pm.RoleID != "" { + query = append(query, "c.role_id = :role_id") + } + if pm.RoleName != "" { + query = append(query, "c.role_name = :role_name") + } + if len(pm.Actions) != 0 { + query = append(query, "c.actions @> :actions") + } + if len(pm.Metadata) > 0 { + query = append(query, "c.metadata @> :metadata") + } + var emq string if len(query) > 0 { emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) @@ -586,28 +928,40 @@ func toDBChannelsPage(pm channels.PageMetadata) (dbChannelsPage, error) { return dbChannelsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) } return dbChannelsPage{ - Name: pm.Name, - Id: pm.Id, - Metadata: data, - Domain: pm.Domain, - Total: pm.Total, - Offset: pm.Offset, - Limit: pm.Limit, - Status: pm.Status, - Tag: pm.Tag, + Limit: pm.Limit, + Offset: pm.Offset, + Name: pm.Name, + Id: pm.Id, + Domain: pm.Domain, + Metadata: data, + Tag: pm.Tag, + Status: pm.Status, + GroupID: pm.Group, + ClientID: pm.Client, + ConnType: pm.ConnectionType, + RoleName: pm.RoleName, + RoleID: pm.RoleID, + Actions: pm.Actions, + AccessType: pm.AccessType, }, nil } type dbChannelsPage struct { - Total uint64 `db:"total"` - Limit uint64 `db:"limit"` - Offset uint64 `db:"offset"` - Name string `db:"name"` - Id string `db:"id"` - Domain string `db:"domain_id"` - Metadata []byte `db:"metadata"` - Tag string `db:"tag"` - Status clients.Status `db:"status"` + Limit uint64 `db:"limit"` + Offset uint64 `db:"offset"` + Name string `db:"name"` + Id string `db:"id"` + Domain string `db:"domain_id"` + Metadata []byte `db:"metadata"` + Tag string `db:"tag"` + Status clients.Status `db:"status"` + GroupID string `db:"group_id"` + ClientID string `db:"client_id"` + ConnType string `db:"type"` + RoleName string `db:"role_name"` + RoleID string `db:"role_id"` + Actions pq.StringArray `db:"actions"` + AccessType string `db:"access_type"` } type dbConnection struct { diff --git a/channels/postgres/init.go b/channels/postgres/init.go index ce503435af..07e82f505c 100644 --- a/channels/postgres/init.go +++ b/channels/postgres/init.go @@ -4,6 +4,7 @@ package postgres import ( + gpostgres "github.com/absmach/supermq/groups/postgres" "github.com/absmach/supermq/pkg/errors" repoerr "github.com/absmach/supermq/pkg/errors/repository" rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres" @@ -56,5 +57,13 @@ func Migration() (*migrate.MemoryMigrationSource, error) { }, } channelsMigration.Migrations = append(channelsMigration.Migrations, rolesMigration.Migrations...) + + groupsMigration, err := gpostgres.Migration() + if err != nil { + return &migrate.MemoryMigrationSource{}, err + } + + channelsMigration.Migrations = append(channelsMigration.Migrations, groupsMigration.Migrations...) + return channelsMigration, nil } diff --git a/channels/roleoperations.go b/channels/roleoperations.go index ae77b002b8..09c15340fe 100644 --- a/channels/roleoperations.go +++ b/channels/roleoperations.go @@ -141,7 +141,7 @@ const ( // External Permission // Domains. domainCreateChannelPermission = "channel_create_permission" - domainListChanelPermission = "list_channels_permission" + domainListChanelPermission = "channel_read_permission" // Groups. groupSetChildChannelPermission = "channel_create_permission" groupRemoveChildChannelPermission = "channel_create_permission" diff --git a/channels/service.go b/channels/service.go index bdcbc10577..4faed423b4 100644 --- a/channels/service.go +++ b/channels/service.go @@ -21,7 +21,6 @@ import ( svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/policies" "github.com/absmach/supermq/pkg/roles" - "golang.org/x/sync/errgroup" ) var ( @@ -183,50 +182,30 @@ func (svc service) ViewChannel(ctx context.Context, session authn.Session, id st } func (svc service) ListChannels(ctx context.Context, session authn.Session, pm PageMetadata) (Page, error) { - var ids []string - var err error - switch session.SuperAdmin { case true: - pm.Domain = session.DomainID + cp, err := svc.repo.RetrieveAll(ctx, pm) + if err != nil { + return Page{}, errors.Wrap(svcerr.ErrViewEntity, err) + } + return cp, nil default: - ids, err = svc.listChannelIDs(ctx, session.DomainUserID, pm.Permission) + cp, err := svc.repo.RetrieveUserChannels(ctx, session.DomainID, session.UserID, pm) if err != nil { - return Page{}, errors.Wrap(svcerr.ErrNotFound, err) + return Page{}, errors.Wrap(svcerr.ErrViewEntity, err) } + return cp, nil } - if len(ids) == 0 && pm.Domain == "" { - return Page{}, nil - } - pm.IDs = ids +} - cp, err := svc.repo.RetrieveAll(ctx, pm) +func (svc service) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm PageMetadata) (Page, error) { + cp, err := svc.repo.RetrieveUserChannels(ctx, session.DomainID, userID, pm) if err != nil { return Page{}, errors.Wrap(svcerr.ErrViewEntity, err) } - - if pm.ListPerms && len(cp.Channels) > 0 { - g, ctx := errgroup.WithContext(ctx) - - for i := range cp.Channels { - // Copying loop variable "i" to avoid "loop variable captured by func literal" - iter := i - g.Go(func() error { - return svc.retrievePermissions(ctx, session.DomainUserID, &cp.Channels[iter]) - }) - } - - if err := g.Wait(); err != nil { - return Page{}, err - } - } return cp, nil } -func (svc service) ListChannelsByClient(ctx context.Context, session authn.Session, clID string, pm PageMetadata) (Page, error) { - return Page{}, nil -} - func (svc service) RemoveChannel(ctx context.Context, session authn.Session, id string) error { ok, err := svc.repo.DoesChannelHaveConnections(ctx, id) if err != nil { @@ -493,41 +472,6 @@ func (svc service) RemoveParentGroup(ctx context.Context, session authn.Session, return nil } -func (svc service) listChannelIDs(ctx context.Context, userID, permission string) ([]string, error) { - tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{ - SubjectType: policies.UserType, - Subject: userID, - Permission: permission, - ObjectType: policies.ChannelType, - }) - if err != nil { - return nil, errors.Wrap(svcerr.ErrNotFound, err) - } - return tids.Policies, nil -} - -func (svc service) retrievePermissions(ctx context.Context, userID string, channel *Channel) error { - permissions, err := svc.listUserClientPermission(ctx, userID, channel.ID) - if err != nil { - return err - } - channel.Permissions = permissions - return nil -} - -func (svc service) listUserClientPermission(ctx context.Context, userID, clientID string) ([]string, error) { - lp, err := svc.policy.ListPermissions(ctx, policies.Policy{ - SubjectType: policies.UserType, - Subject: userID, - Object: clientID, - ObjectType: policies.ChannelType, - }, []string{}) - if err != nil { - return []string{}, errors.Wrap(svcerr.ErrAuthorization, err) - } - return lp, nil -} - func (svc service) changeChannelStatus(ctx context.Context, userID string, channel Channel) (Channel, error) { dbchannel, err := svc.repo.RetrieveByID(ctx, channel.ID) if err != nil { diff --git a/channels/service_test.go b/channels/service_test.go index f4457d913a..d5dbf4974f 100644 --- a/channels/service_test.go +++ b/channels/service_test.go @@ -21,6 +21,7 @@ import ( gpmocks "github.com/absmach/supermq/groups/mocks" "github.com/absmach/supermq/internal/testsutil" "github.com/absmach/supermq/pkg/authn" + smqauthn "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/connections" "github.com/absmach/supermq/pkg/errors" repoerr "github.com/absmach/supermq/pkg/errors/repository" @@ -459,220 +460,180 @@ func TestDisableChannel(t *testing.T) { func TestListChannels(t *testing.T) { svc := newService(t) - channelWithPerms := validChannel - channelWithPerms.Permissions = []string{policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission} + adminID := testsutil.GenerateUUID(t) + domainID := testsutil.GenerateUUID(t) + nonAdminID := testsutil.GenerateUUID(t) cases := []struct { - desc string - session authn.Session - pageMeta channels.PageMetadata - listAllObjectsRes policysvc.PolicyPage - listAllObjectsErr error - retrieveAllRes channels.Page - retrieveAllErr error - listPermissionsRes policysvc.Permissions - listPermissionsErr error - resp channels.Page - err error + desc string + userKind string + session smqauthn.Session + page channels.PageMetadata + retrieveAllResponse channels.Page + response channels.Page + id string + size uint64 + listObjectsErr error + retrieveAllErr error + listPermissionsErr error + err error }{ { - desc: "list channesls as admin successfully", - session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true}, - pageMeta: channels.PageMetadata{ - Domain: validID, + desc: "list all channels successfully as non admin", + userKind: "non-admin", + session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false}, + id: nonAdminID, + page: channels.PageMetadata{ + Offset: 0, + Limit: 100, }, - retrieveAllRes: channels.Page{ - Channels: []channels.Channel{validChannel}, + retrieveAllResponse: channels.Page{ PageMetadata: channels.PageMetadata{ - Total: 1, + Total: 2, + Offset: 0, + Limit: 100, }, + Channels: []channels.Channel{validChannel, validChannel}, }, - resp: channels.Page{ - Channels: []channels.Channel{validChannel}, + response: channels.Page{ PageMetadata: channels.PageMetadata{ - Total: 1, + Total: 2, + Offset: 0, + Limit: 100, }, + Channels: []channels.Channel{validChannel, validChannel}, }, err: nil, }, { - desc: "list channels as admin with list perms successfully", - session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true}, - pageMeta: channels.PageMetadata{ - Domain: validID, - ListPerms: true, - }, - listPermissionsRes: policysvc.Permissions{ - policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission, - }, - retrieveAllRes: channels.Page{ - Channels: []channels.Channel{validChannel}, - PageMetadata: channels.PageMetadata{ - Total: 1, - }, - }, - resp: channels.Page{ - Channels: []channels.Channel{channelWithPerms}, - PageMetadata: channels.PageMetadata{ - Total: 1, - }, + desc: "list all channels as non admin with failed to retrieve all", + userKind: "non-admin", + session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false}, + id: nonAdminID, + page: channels.PageMetadata{ + Offset: 0, + Limit: 100, }, - err: nil, + retrieveAllResponse: channels.Page{}, + response: channels.Page{}, + retrieveAllErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, }, { - desc: "list channels as admin with failed to retrieve all", - session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true}, - pageMeta: channels.PageMetadata{ - Domain: validID, - }, - retrieveAllRes: channels.Page{}, - retrieveAllErr: repoerr.ErrNotFound, - err: repoerr.ErrNotFound, - }, - { - desc: "list channels as admin with failed to list permissions", - session: authn.Session{UserID: validID, DomainID: validID, DomainUserID: validID, SuperAdmin: true}, - pageMeta: channels.PageMetadata{ - Domain: validID, - ListPerms: true, + desc: "list all channels as non admin with failed super admin", + userKind: "non-admin", + session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false}, + id: nonAdminID, + page: channels.PageMetadata{ + Offset: 0, + Limit: 100, }, - retrieveAllRes: channels.Page{ - Channels: []channels.Channel{validChannel}, - PageMetadata: channels.PageMetadata{ - Total: 1, - }, - }, - listPermissionsRes: policysvc.Permissions{}, - listPermissionsErr: svcerr.ErrAuthorization, - err: svcerr.ErrAuthorization, - }, - { - desc: "list channels as admin with no domain id", - session: authn.Session{UserID: validID, SuperAdmin: true}, - pageMeta: channels.PageMetadata{}, + response: channels.Page{}, err: nil, }, { - desc: "list channels as user successfully", - session: validSession, - pageMeta: channels.PageMetadata{ - Permission: policysvc.ViewPermission, - IDs: []string{validChannel.ID}, - }, - listAllObjectsRes: policysvc.PolicyPage{ - Policies: []string{validChannel.ID}, + desc: "list all channels as non admin with failed to list objects", + userKind: "non-admin", + id: nonAdminID, + page: channels.PageMetadata{ + Offset: 0, + Limit: 100, }, - retrieveAllRes: channels.Page{ - Channels: []channels.Channel{validChannel}, - PageMetadata: channels.PageMetadata{ - Total: 1, - }, - }, - resp: channels.Page{ - Channels: []channels.Channel{validChannel}, - PageMetadata: channels.PageMetadata{ - Total: 1, - }, - }, - err: nil, - }, - { - desc: "list channels as user with failed to list all objects", - session: validSession, - pageMeta: channels.PageMetadata{ - Permission: policysvc.ViewPermission, - IDs: []string{validChannel.ID}, - }, - listAllObjectsErr: svcerr.ErrAuthorization, - err: svcerr.ErrAuthorization, + retrieveAllErr: repoerr.ErrNotFound, + response: channels.Page{}, + listObjectsErr: svcerr.ErrNotFound, + err: svcerr.ErrNotFound, }, + } + + for _, tc := range cases { + retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) + retrieveUserClientsCall := repo.On("RetrieveUserChannels", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) + page, err := svc.ListChannels(context.Background(), tc.session, tc.page) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) + retrieveAllCall.Unset() + retrieveUserClientsCall.Unset() + } + + cases2 := []struct { + desc string + userKind string + session smqauthn.Session + page channels.PageMetadata + retrieveAllResponse channels.Page + response channels.Page + id string + size uint64 + listObjectsErr error + retrieveAllErr error + listPermissionsErr error + err error + }{ { - desc: "list channels as user with list permissions successfully", - session: validSession, - pageMeta: channels.PageMetadata{ - Permission: policysvc.ViewPermission, - IDs: []string{validChannel.ID}, - ListPerms: true, - }, - listAllObjectsRes: policysvc.PolicyPage{ - Policies: []string{validChannel.ID}, + desc: "list all clients as admin successfully", + userKind: "admin", + id: adminID, + session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true}, + page: channels.PageMetadata{ + Offset: 0, + Limit: 100, + Domain: domainID, }, - retrieveAllRes: channels.Page{ - Channels: []channels.Channel{validChannel}, + retrieveAllResponse: channels.Page{ PageMetadata: channels.PageMetadata{ - Total: 1, + Total: 2, + Offset: 0, + Limit: 100, }, + Channels: []channels.Channel{validChannel, validChannel}, }, - listPermissionsRes: policysvc.Permissions{ - policysvc.AdminPermission, policysvc.EditPermission, policysvc.ViewPermission, - }, - resp: channels.Page{ - Channels: []channels.Channel{channelWithPerms}, + response: channels.Page{ PageMetadata: channels.PageMetadata{ - Total: 1, + Total: 2, + Offset: 0, + Limit: 100, }, + Channels: []channels.Channel{validChannel, validChannel}, }, err: nil, }, { - desc: "list channels as user with list permissions and failed to list permissions", - session: validSession, - pageMeta: channels.PageMetadata{ - Permission: policysvc.ViewPermission, - IDs: []string{validChannel.ID}, - ListPerms: true, + desc: "list all clients as admin with failed to retrieve all", + userKind: "admin", + id: adminID, + session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true}, + page: channels.PageMetadata{ + Offset: 0, + Limit: 100, + Domain: domainID, }, - listAllObjectsRes: policysvc.PolicyPage{ - Policies: []string{validChannel.ID}, - }, - retrieveAllRes: channels.Page{ - Channels: []channels.Channel{validChannel}, - PageMetadata: channels.PageMetadata{ - Total: 1, - }, - }, - listPermissionsRes: policysvc.Permissions{}, - listPermissionsErr: svcerr.ErrAuthorization, - err: svcerr.ErrAuthorization, + retrieveAllResponse: channels.Page{}, + retrieveAllErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, }, { - desc: "list channels as user with failed to retrieve all", - session: validSession, - pageMeta: channels.PageMetadata{ - Permission: policysvc.ViewPermission, - IDs: []string{validChannel.ID}, + desc: "list all clients as admin with failed to list clients", + userKind: "admin", + id: adminID, + session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true}, + page: channels.PageMetadata{ + Offset: 0, + Limit: 100, + Domain: domainID, }, - listAllObjectsRes: policysvc.PolicyPage{ - Policies: []string{validChannel.ID}, - }, - retrieveAllRes: channels.Page{}, - retrieveAllErr: repoerr.ErrNotFound, - err: repoerr.ErrNotFound, + retrieveAllResponse: channels.Page{}, + retrieveAllErr: repoerr.ErrNotFound, + err: svcerr.ErrNotFound, }, } - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - policyCall := policies.On("ListAllObjects", context.Background(), policysvc.Policy{ - SubjectType: policysvc.UserType, - Subject: validID, - Permission: policysvc.ViewPermission, - ObjectType: policysvc.ChannelType, - }).Return(tc.listAllObjectsRes, tc.listAllObjectsErr) - repoCall := repo.On("RetrieveAll", context.Background(), tc.pageMeta).Return(tc.retrieveAllRes, tc.retrieveAllErr) - policyCall1 := policies.On("ListPermissions", mock.Anything, policysvc.Policy{ - SubjectType: policysvc.UserType, - Subject: validID, - Object: validChannel.ID, - ObjectType: policysvc.ChannelType, - }, []string{}).Return(tc.listPermissionsRes, tc.listPermissionsErr) - got, err := svc.ListChannels(context.Background(), tc.session, tc.pageMeta) - assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("expected error %v to contain %v", err, tc.err)) - assert.Equal(t, tc.resp, got) - policyCall.Unset() - repoCall.Unset() - policyCall1.Unset() - }) + for _, tc := range cases2 { + retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) + page, err := svc.ListChannels(context.Background(), tc.session, tc.page) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) + retrieveAllCall.Unset() } } diff --git a/channels/tracing/tracing.go b/channels/tracing/tracing.go index 768ceb722b..b5114261f5 100644 --- a/channels/tracing/tracing.go +++ b/channels/tracing/tracing.go @@ -50,10 +50,10 @@ func (tm *tracingMiddleware) ListChannels(ctx context.Context, session authn.Ses return tm.svc.ListChannels(ctx, session, pm) } -func (tm *tracingMiddleware) ListChannelsByClient(ctx context.Context, session authn.Session, clientID string, pm channels.PageMetadata) (channels.Page, error) { - ctx, span := tm.tracer.Start(ctx, "svc_list_channels") +func (tm *tracingMiddleware) ListUserChannels(ctx context.Context, session authn.Session, userID string, pm channels.PageMetadata) (channels.Page, error) { + ctx, span := tm.tracer.Start(ctx, "svc_list_user_channels") defer span.End() - return tm.svc.ListChannelsByClient(ctx, session, clientID, pm) + return tm.svc.ListUserChannels(ctx, session, userID, pm) } // UpdateChannel traces the "UpdateChannel" operation of the wrapped policies.Service. diff --git a/cli/users.go b/cli/users.go index 3772ebefd2..efbe355e12 100644 --- a/cli/users.go +++ b/cli/users.go @@ -370,32 +370,6 @@ var cmdUsers = []cobra.Command{ logOKCmd(*cmd) }, }, - { - Use: "clients ", - Short: "List clients", - Long: "List clients of user\n" + - "Usage:\n" + - "\tsupermq-cli users clients \n", - Run: func(cmd *cobra.Command, args []string) { - if len(args) != 3 { - logUsageCmd(*cmd, cmd.Use) - return - } - - pm := smqsdk.PageMetadata{ - Offset: Offset, - Limit: Limit, - } - - tp, err := sdk.ListUserClients(args[0], args[1], pm, args[2]) - if err != nil { - logErrorCmd(*cmd, err) - return - } - - logJSONCmd(*cmd, tp) - }, - }, { Use: "search ", Short: "Search users", diff --git a/clients/api/http/decode.go b/clients/api/http/decode.go index 4c2377bbb6..71559d7e17 100644 --- a/clients/api/http/decode.go +++ b/clients/api/http/decode.go @@ -27,58 +27,112 @@ func decodeViewClient(_ context.Context, r *http.Request) (interface{}, error) { } func decodeListClients(_ context.Context, r *http.Request) (interface{}, error) { - s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefClientStatus) + name, err := apiutil.ReadStringQuery(r, api.NameKey, "") if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + + tag, err := apiutil.ReadStringQuery(r, api.TagKey, "") if err != nil { return nil, errors.Wrap(apiutil.ErrValidation, err) } - l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + + s, err := apiutil.ReadStringQuery(r, api.StatusKey, api.DefGroupStatus) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - m, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil) + status, err := clients.ToStatus(s) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - n, err := apiutil.ReadStringQuery(r, api.NameKey, "") + + meta, err := apiutil.ReadMetadataQuery(r, api.MetadataKey, nil) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - t, err := apiutil.ReadStringQuery(r, api.TagKey, "") + + offset, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - id, err := apiutil.ReadStringQuery(r, api.IDOrder, "") + limit, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - p, err := apiutil.ReadStringQuery(r, api.PermissionKey, api.DefPermission) + + dir, err := apiutil.ReadStringQuery(r, api.DirKey, api.DefDir) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - lp, err := apiutil.ReadBoolQuery(r, api.ListPerms, api.DefListPerms) + order, err := apiutil.ReadStringQuery(r, api.OrderKey, api.DefOrder) if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } - st, err := clients.ToStatus(s) + + allActions, err := apiutil.ReadStringQuery(r, api.ActionsKey, "") if err != nil { - return nil, errors.Wrap(apiutil.ErrValidation, err) + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + actions := []string{} + + allActions = strings.TrimSpace(allActions) + if allActions != "" { + actions = strings.Split(allActions, ",") + } + roleID, err := apiutil.ReadStringQuery(r, api.RoleIDKey, "") + if err != nil { + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) } + + roleName, err := apiutil.ReadStringQuery(r, api.RoleNameKey, "") + if err != nil { + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + accessType, err := apiutil.ReadStringQuery(r, api.AccessTypeKey, "") + if err != nil { + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + userID, err := apiutil.ReadStringQuery(r, api.UserKey, "") + if err != nil { + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + groupID, err := apiutil.ReadStringQuery(r, api.GroupKey, "") + if err != nil { + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + channelID, err := apiutil.ReadStringQuery(r, api.ChannelKey, "") + if err != nil { + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + + connType, err := apiutil.ReadStringQuery(r, api.ConnTypeKey, "") + if err != nil { + return listClientsReq{}, errors.Wrap(apiutil.ErrValidation, err) + } + req := listClientsReq{ - status: st, - offset: o, - limit: l, - metadata: m, - name: n, - tag: t, - permission: p, - listPerms: lp, - userID: chi.URLParam(r, "userID"), - id: id, + name: name, + tag: tag, + status: status, + metadata: meta, + roleName: roleName, + roleID: roleID, + actions: actions, + accessType: accessType, + order: order, + dir: dir, + offset: offset, + limit: limit, + groupID: groupID, + channelID: channelID, + connType: connType, + userID: userID, } return req, nil } diff --git a/clients/api/http/endpoints.go b/clients/api/http/endpoints.go index 575e41949c..25cbc25f0d 100644 --- a/clients/api/http/endpoints.go +++ b/clients/api/http/endpoints.go @@ -104,19 +104,33 @@ func listClientsEndpoint(svc clients.Service) endpoint.Endpoint { } pm := clients.Page{ - Status: req.status, - Offset: req.offset, - Limit: req.limit, - Name: req.name, - Tag: req.tag, - Permission: req.permission, - Metadata: req.metadata, - ListPerms: req.listPerms, - Id: req.id, - } - page, err := svc.ListClients(ctx, session, req.userID, pm) + Name: req.name, + Tag: req.tag, + Status: req.status, + Metadata: req.metadata, + RoleName: req.roleName, + RoleID: req.roleID, + Actions: req.actions, + AccessType: req.accessType, + Order: req.order, + Dir: req.dir, + Offset: req.offset, + Limit: req.limit, + Group: req.groupID, + Channel: req.channelID, + ConnectionType: req.connType, + } + + var page clients.ClientsPage + var err error + switch req.userID != "" { + case true: + page, err = svc.ListUserClients(ctx, session, req.userID, pm) + default: + page, err = svc.ListClients(ctx, session, pm) + } if err != nil { - return nil, err + return clientsPageRes{}, err } res := clientsPageRes{ diff --git a/clients/api/http/endpoints_test.go b/clients/api/http/endpoints_test.go index 2ef69eb636..8104737cc3 100644 --- a/clients/api/http/endpoints_test.go +++ b/clients/api/http/endpoints_test.go @@ -743,7 +743,7 @@ func TestListClients(t *testing.T) { } authCall := authn.On("Authenticate", mock.Anything, tc.token).Return(tc.authnRes, tc.authnErr) - svcCall := svc.On("ListClients", mock.Anything, tc.authnRes, "", mock.Anything).Return(tc.listClientsResponse, tc.err) + svcCall := svc.On("ListClients", mock.Anything, tc.authnRes, mock.Anything).Return(tc.listClientsResponse, tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) diff --git a/clients/api/http/requests.go b/clients/api/http/requests.go index ae2e586f21..9e686b0a4c 100644 --- a/clients/api/http/requests.go +++ b/clients/api/http/requests.go @@ -71,29 +71,29 @@ func (req viewClientPermsReq) validate() error { } type listClientsReq struct { + name string + tag string status clients.Status + metadata clients.Metadata + roleName string + roleID string + actions []string + accessType string + order string + dir string offset uint64 limit uint64 - name string - tag string - permission string - visibility string + groupID string + channelID string + connType string userID string - listPerms bool - metadata clients.Metadata - id string } func (req listClientsReq) validate() error { if req.limit > api.MaxLimitSize || req.limit < 1 { return apiutil.ErrLimitSize } - if req.visibility != "" && - req.visibility != api.AllVisibility && - req.visibility != api.MyVisibility && - req.visibility != api.SharedVisibility { - return apiutil.ErrInvalidVisibilityType - } + if len(req.name) > api.MaxNameSize { return apiutil.ErrNameSize } diff --git a/clients/api/http/requests_test.go b/clients/api/http/requests_test.go index cc9fcc0d5a..7d110fd506 100644 --- a/clients/api/http/requests_test.go +++ b/clients/api/http/requests_test.go @@ -210,14 +210,6 @@ func TestListClientsReqValidate(t *testing.T) { }, err: apiutil.ErrLimitSize, }, - { - desc: "invalid visibility", - req: listClientsReq{ - limit: 10, - visibility: "invalid", - }, - err: apiutil.ErrInvalidVisibilityType, - }, { desc: "name too long", req: listClientsReq{ diff --git a/clients/clients.go b/clients/clients.go index a4194cd90d..5c09f21e36 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -13,6 +13,10 @@ import ( "github.com/absmach/supermq/pkg/roles" ) +type CtxKey int + +const ListDomainClients CtxKey = iota + type Connection struct { ClientID string ChannelID string @@ -35,11 +39,14 @@ type Repository interface { // RetrieveAll retrieves all clients. RetrieveAll(ctx context.Context, pm Page) (ClientsPage, error) + // RetrieveUserClients retrieve all clients of a given user id. + RetrieveUserClients(ctx context.Context, domainID, userID string, pm Page) (ClientsPage, error) + // SearchClients retrieves clients based on search criteria. SearchClients(ctx context.Context, pm Page) (ClientsPage, error) - // RetrieveAllByIDs retrieves for given client IDs . - RetrieveAllByIDs(ctx context.Context, pm Page) (ClientsPage, error) + // RetrieveByIds + RetrieveByIds(ctx context.Context, ids []string) (ClientsPage, error) // Update updates the client name and metadata. Update(ctx context.Context, client Client) (Client, error) @@ -66,8 +73,6 @@ type Repository interface { // RetrieveBySecret retrieves a client based on the secret (key). RetrieveBySecret(ctx context.Context, key string) (Client, error) - RetrieveByIds(ctx context.Context, ids []string) (ClientsPage, error) - AddConnections(ctx context.Context, conns []Connection) error RemoveConnections(ctx context.Context, conns []Connection) error @@ -105,8 +110,11 @@ type Service interface { // View retrieves client info for a given client ID and an authorized token. View(ctx context.Context, session authn.Session, id string) (Client, error) - // ListClients retrieves clients list for a valid auth token. - ListClients(ctx context.Context, session authn.Session, reqUserID string, pm Page) (ClientsPage, error) + // ListClients retrieves clients list for given page query. + ListClients(ctx context.Context, session authn.Session, pm Page) (ClientsPage, error) + + // ListUserClients retrieves clients list for a given user id and page query. + ListUserClients(ctx context.Context, session authn.Session, userID string, pm Page) (ClientsPage, error) // Update updates the client's name and metadata. Update(ctx context.Context, session authn.Session, client Client) (Client, error) @@ -161,8 +169,17 @@ type Client struct { UpdatedAt time.Time `json:"updated_at,omitempty"` UpdatedBy string `json:"updated_by,omitempty"` Status Status `json:"status,omitempty"` // 1 for enabled, 0 for disabled - Permissions []string `json:"permissions,omitempty"` Identity string `json:"identity,omitempty"` + // Extended + ParentGroupPath string `json:"parent_group_path,omitempty"` + RoleID string `json:"role_id,omitempty"` + RoleName string `json:"role_name,omitempty"` + Actions []string `json:"actions,omitempty"` + AccessType string `json:"access_type,omitempty"` + AccessProviderId string `json:"access_provider_id,omitempty"` + AccessProviderRoleId string `json:"access_provider_role_id,omitempty"` + AccessProviderRoleName string `json:"access_provider_role_name,omitempty"` + AccessProviderRoleActions []string `json:"access_provider_role_actions,omitempty"` } // ClientsPage contains page related metadata as well as list. @@ -182,21 +199,26 @@ type MembersPage struct { // Page contains the page metadata that helps navigation. type Page struct { - Total uint64 `json:"total"` - Offset uint64 `json:"offset"` - Limit uint64 `json:"limit"` - Name string `json:"name,omitempty"` - Id string `json:"id,omitempty"` - Order string `json:"order,omitempty"` - Dir string `json:"dir,omitempty"` - Metadata Metadata `json:"metadata,omitempty"` - Domain string `json:"domain,omitempty"` - Tag string `json:"tag,omitempty"` - Permission string `json:"permission,omitempty"` - Status Status `json:"status,omitempty"` - IDs []string `json:"ids,omitempty"` - Identity string `json:"identity,omitempty"` - ListPerms bool `json:"-"` + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Order string `json:"order,omitempty"` + Dir string `json:"dir,omitempty"` + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Metadata Metadata `json:"metadata,omitempty"` + Domain string `json:"domain,omitempty"` + Tag string `json:"tag,omitempty"` + Status Status `json:"status,omitempty"` + Identity string `json:"identity,omitempty"` + Group string `json:"group,omitempty"` + Channel string `json:"channel,omitempty"` + ConnectionType string `json:"connection_type,omitempty"` + RoleName string `json:"role_name,omitempty"` + RoleID string `json:"role_id,omitempty"` + Actions []string `json:"actions,omitempty"` + AccessType string `json:"access_type,omitempty"` + IDs []string `json:"-"` } // Metadata represents arbitrary JSON. diff --git a/clients/events/events.go b/clients/events/events.go index 7d42ef35d1..0405dffb93 100644 --- a/clients/events/events.go +++ b/clients/events/events.go @@ -205,7 +205,6 @@ func (vcpe viewClientPermsEvent) Encode() (map[string]interface{}, error) { } type listClientEvent struct { - reqUserID string clients.Page authn.Session } @@ -213,7 +212,6 @@ type listClientEvent struct { func (lce listClientEvent) Encode() (map[string]interface{}, error) { val := map[string]interface{}{ "operation": clientList, - "reqUserID": lce.reqUserID, "total": lce.Total, "offset": lce.Offset, "limit": lce.Limit, @@ -238,8 +236,51 @@ func (lce listClientEvent) Encode() (map[string]interface{}, error) { if lce.Tag != "" { val["tag"] = lce.Tag } - if lce.Permission != "" { - val["permission"] = lce.Permission + if lce.Status.String() != "" { + val["status"] = lce.Status.String() + } + if len(lce.IDs) > 0 { + val["ids"] = lce.IDs + } + if lce.Identity != "" { + val["identity"] = lce.Identity + } + return val, nil +} + +type listUserClientEvent struct { + userID string + clients.Page + authn.Session +} + +func (lce listUserClientEvent) Encode() (map[string]interface{}, error) { + val := map[string]interface{}{ + "operation": clientList, + "req_user_id": lce.userID, + "total": lce.Total, + "offset": lce.Offset, + "limit": lce.Limit, + "domain": lce.DomainID, + "user_id": lce.UserID, + "token_type": lce.Type.String(), + "super_admin": lce.SuperAdmin, + } + + if lce.Name != "" { + val["name"] = lce.Name + } + if lce.Order != "" { + val["order"] = lce.Order + } + if lce.Dir != "" { + val["dir"] = lce.Dir + } + if lce.Metadata != nil { + val["metadata"] = lce.Metadata + } + if lce.Tag != "" { + val["tag"] = lce.Tag } if lce.Status.String() != "" { val["status"] = lce.Status.String() @@ -288,9 +329,6 @@ func (lcge listClientByGroupEvent) Encode() (map[string]interface{}, error) { if lcge.Tag != "" { val["tag"] = lcge.Tag } - if lcge.Permission != "" { - val["permission"] = lcge.Permission - } if lcge.Status.String() != "" { val["status"] = lcge.Status.String() } diff --git a/clients/events/streams.go b/clients/events/streams.go index 8e13c2a3c7..1473c04c0d 100644 --- a/clients/events/streams.go +++ b/clients/events/streams.go @@ -118,15 +118,31 @@ func (es *eventStore) View(ctx context.Context, session authn.Session, id string return cli, nil } -func (es *eventStore) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) { - cp, err := es.svc.ListClients(ctx, session, reqUserID, pm) +func (es *eventStore) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) { + cp, err := es.svc.ListClients(ctx, session, pm) if err != nil { return cp, err } event := listClientEvent{ - reqUserID: reqUserID, - Page: pm, - Session: session, + pm, + session, + } + if err := es.Publish(ctx, event); err != nil { + return cp, err + } + + return cp, nil +} + +func (es *eventStore) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) { + cp, err := es.svc.ListUserClients(ctx, session, userID, pm) + if err != nil { + return cp, err + } + event := listUserClientEvent{ + userID, + pm, + session, } if err := es.Publish(ctx, event); err != nil { return cp, err diff --git a/clients/middleware/authorization.go b/clients/middleware/authorization.go index 7130360d8e..8a3483b601 100644 --- a/clients/middleware/authorization.go +++ b/clients/middleware/authorization.go @@ -129,7 +129,7 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi return am.svc.View(ctx, session, id) } -func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) { +func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) { if session.Type == authn.PersonalAccessToken { if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ UserID: session.UserID, @@ -144,11 +144,33 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth } } - if err := am.checkSuperAdmin(ctx, session.UserID); err != nil { + if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } - return am.svc.ListClients(ctx, session, reqUserID, pm) + return am.svc.ListClients(ctx, session, pm) +} + +func (am *authorizationMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) { + if session.Type == authn.PersonalAccessToken { + if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ + UserID: session.UserID, + PatID: session.PatID, + PlatformEntityType: auth.PlatformDomainsScope, + OptionalDomainID: session.DomainID, + OptionalDomainEntityType: auth.DomainClientsScope, + Operation: auth.ListOp, + EntityIDs: auth.AnyIDs{}.Values(), + }); err != nil { + return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + } + + if err := am.checkSuperAdmin(ctx, session.UserID); err != nil { + return clients.ClientsPage{}, err + } + + return am.svc.ListUserClients(ctx, session, userID, pm) } func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) { diff --git a/clients/middleware/logging.go b/clients/middleware/logging.go index 378409b41d..cc3eca477d 100644 --- a/clients/middleware/logging.go +++ b/clients/middleware/logging.go @@ -65,11 +65,10 @@ func (lm *loggingMiddleware) View(ctx context.Context, session authn.Session, id return lm.svc.View(ctx, session, id) } -func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (cp clients.ClientsPage, err error) { +func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (cp clients.ClientsPage, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), - slog.String("user_id", reqUserID), slog.Group("page", slog.Uint64("limit", pm.Limit), slog.Uint64("offset", pm.Offset), @@ -83,7 +82,28 @@ func (lm *loggingMiddleware) ListClients(ctx context.Context, session authn.Sess } lm.logger.Info("List clients completed successfully", args...) }(time.Now()) - return lm.svc.ListClients(ctx, session, reqUserID, pm) + return lm.svc.ListClients(ctx, session, pm) +} + +func (lm *loggingMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (cp clients.ClientsPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("user_id", userID), + slog.Group("page", + slog.Uint64("limit", pm.Limit), + slog.Uint64("offset", pm.Offset), + slog.Uint64("total", cp.Total), + ), + } + if err != nil { + args = append(args, slog.String("error", err.Error())) + lm.logger.Warn("List clients failed", args...) + return + } + lm.logger.Info("List clients completed successfully", args...) + }(time.Now()) + return lm.svc.ListUserClients(ctx, session, userID, pm) } func (lm *loggingMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (c clients.Client, err error) { diff --git a/clients/middleware/metrics.go b/clients/middleware/metrics.go index ca79ae903f..12a4b1d805 100644 --- a/clients/middleware/metrics.go +++ b/clients/middleware/metrics.go @@ -49,12 +49,20 @@ func (ms *metricsMiddleware) View(ctx context.Context, session authn.Session, id return ms.svc.View(ctx, session, id) } -func (ms *metricsMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) { +func (ms *metricsMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) { defer func(begin time.Time) { ms.counter.With("method", "list_clients").Add(1) ms.latency.With("method", "list_clients").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.ListClients(ctx, session, reqUserID, pm) + return ms.svc.ListClients(ctx, session, pm) +} + +func (ms *metricsMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) { + defer func(begin time.Time) { + ms.counter.With("method", "list_user_clients").Add(1) + ms.latency.With("method", "list_user_clients").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ListUserClients(ctx, session, userID, pm) } func (ms *metricsMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) { diff --git a/clients/mocks/repository.go b/clients/mocks/repository.go index 873aadb1fc..fc45044a2e 100644 --- a/clients/mocks/repository.go +++ b/clients/mocks/repository.go @@ -312,34 +312,6 @@ func (_m *Repository) RetrieveAll(ctx context.Context, pm clients.Page) (clients return r0, r1 } -// RetrieveAllByIDs provides a mock function with given fields: ctx, pm -func (_m *Repository) RetrieveAllByIDs(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) { - ret := _m.Called(ctx, pm) - - if len(ret) == 0 { - panic("no return value specified for RetrieveAllByIDs") - } - - var r0 clients.ClientsPage - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, clients.Page) (clients.ClientsPage, error)); ok { - return rf(ctx, pm) - } - if rf, ok := ret.Get(0).(func(context.Context, clients.Page) clients.ClientsPage); ok { - r0 = rf(ctx, pm) - } else { - r0 = ret.Get(0).(clients.ClientsPage) - } - - if rf, ok := ret.Get(1).(func(context.Context, clients.Page) error); ok { - r1 = rf(ctx, pm) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - // RetrieveAllRoles provides a mock function with given fields: ctx, entityID, limit, offset func (_m *Repository) RetrieveAllRoles(ctx context.Context, entityID string, limit uint64, offset uint64) (roles.RolePage, error) { ret := _m.Called(ctx, entityID, limit, offset) @@ -577,6 +549,34 @@ func (_m *Repository) RetrieveRole(ctx context.Context, roleID string) (roles.Ro return r0, r1 } +// RetrieveUserClients provides a mock function with given fields: ctx, domainID, userID, pm +func (_m *Repository) RetrieveUserClients(ctx context.Context, domainID string, userID string, pm clients.Page) (clients.ClientsPage, error) { + ret := _m.Called(ctx, domainID, userID, pm) + + if len(ret) == 0 { + panic("no return value specified for RetrieveUserClients") + } + + var r0 clients.ClientsPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, clients.Page) (clients.ClientsPage, error)); ok { + return rf(ctx, domainID, userID, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, clients.Page) clients.ClientsPage); ok { + r0 = rf(ctx, domainID, userID, pm) + } else { + r0 = ret.Get(0).(clients.ClientsPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, clients.Page) error); ok { + r1 = rf(ctx, domainID, userID, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RoleAddActions provides a mock function with given fields: ctx, role, actions func (_m *Repository) RoleAddActions(ctx context.Context, role roles.Role, actions []string) ([]string, error) { ret := _m.Called(ctx, role, actions) diff --git a/clients/mocks/service.go b/clients/mocks/service.go index 08baa914ce..3d64f162cd 100644 --- a/clients/mocks/service.go +++ b/clients/mocks/service.go @@ -198,27 +198,55 @@ func (_m *Service) ListAvailableActions(ctx context.Context, session authn.Sessi return r0, r1 } -// ListClients provides a mock function with given fields: ctx, session, reqUserID, pm -func (_m *Service) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) { - ret := _m.Called(ctx, session, reqUserID, pm) +// ListClients provides a mock function with given fields: ctx, session, pm +func (_m *Service) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) { + ret := _m.Called(ctx, session, pm) if len(ret) == 0 { panic("no return value specified for ListClients") } + var r0 clients.ClientsPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, clients.Page) (clients.ClientsPage, error)); ok { + return rf(ctx, session, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, clients.Page) clients.ClientsPage); ok { + r0 = rf(ctx, session, pm) + } else { + r0 = ret.Get(0).(clients.ClientsPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, authn.Session, clients.Page) error); ok { + r1 = rf(ctx, session, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListUserClients provides a mock function with given fields: ctx, session, userID, pm +func (_m *Service) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) { + ret := _m.Called(ctx, session, userID, pm) + + if len(ret) == 0 { + panic("no return value specified for ListUserClients") + } + var r0 clients.ClientsPage var r1 error if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, clients.Page) (clients.ClientsPage, error)); ok { - return rf(ctx, session, reqUserID, pm) + return rf(ctx, session, userID, pm) } if rf, ok := ret.Get(0).(func(context.Context, authn.Session, string, clients.Page) clients.ClientsPage); ok { - r0 = rf(ctx, session, reqUserID, pm) + r0 = rf(ctx, session, userID, pm) } else { r0 = ret.Get(0).(clients.ClientsPage) } if rf, ok := ret.Get(1).(func(context.Context, authn.Session, string, clients.Page) error); ok { - r1 = rf(ctx, session, reqUserID, pm) + r1 = rf(ctx, session, userID, pm) } else { r1 = ret.Error(1) } diff --git a/clients/postgres/clients.go b/clients/postgres/clients.go index 62717f9edf..3fa6719312 100644 --- a/clients/postgres/clients.go +++ b/clients/postgres/clients.go @@ -20,6 +20,7 @@ import ( "github.com/absmach/supermq/pkg/postgres" rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres" "github.com/jackc/pgtype" + "github.com/lib/pq" ) const ( @@ -247,25 +248,57 @@ func (repo *clientRepo) RetrieveAll(ctx context.Context, pm clients.Page) (clien return page, nil } -func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) { - query, err := PageQuery(pm) - if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) - } - - tq := query - query = applyOrdering(query, pm) +func (repo *clientRepo) RetrieveUserClients(ctx context.Context, domainID, userID string, pm clients.Page) (clients.ClientsPage, error) { + return repo.retrieveClients(ctx, domainID, userID, pm) +} - q := fmt.Sprintf(`SELECT c.id, c.name, c.created_at, c.updated_at FROM clients c %s LIMIT :limit OFFSET :offset;`, query) +func (repo *clientRepo) retrieveClients(ctx context.Context, domainID, userID string, pm clients.Page) (clients.ClientsPage, error) { + pageQuery, err := PageQuery(pm) + if err != nil { + return clients.ClientsPage{}, err + } + + bq := repo.userClientBaseQuery(domainID, userID) + + q := fmt.Sprintf(` + %s + SELECT + c.id, + c.name, + c.domain_id, + c.parent_group_id, + c.identity, + c.secret, + c.tags, + c.metadata, + c.created_at, + c.updated_at, + c.updated_by, + c.status, + c.parent_group_path, + c.role_id, + c.role_name, + c.actions, + c.access_type, + c.access_provider_id, + c.access_provider_role_id, + c.access_provider_role_name, + c.access_provider_role_actions + FROM + final_clients c + %s + `, bq, pageQuery) + + q = applyOrdering(q, pm) dbPage, err := ToDBClientsPage(pm) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) } rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage) if err != nil { - return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) + return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) } defer rows.Close() @@ -284,7 +317,42 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli items = append(items, c) } - cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, tq) + connJoinQuery := "" + if pm.Channel != "" { + connJoinQuery = "JOIN connection conn ON conn.client_id = c.id" + } + cq := fmt.Sprintf(`%s + SELECT COUNT(*) AS total_count + FROM ( + SELECT + c.id, + c.name, + c.domain_id, + c.parent_group_id, + c.identity, + c.secret, + c.tags, + c.metadata, + c.created_at, + c.updated_at, + c.updated_by, + c.status, + c.parent_group_path, + c.role_id, + c.role_name, + c.actions, + c.access_type, + c.access_provider_id, + c.access_provider_role_id, + c.access_provider_role_name, + c.access_provider_role_actions + FROM + final_clients c + %s + %s + ) AS subquery; + `, bq, connJoinQuery, pageQuery) + total, err := postgres.Total(ctx, repo.DB, cq, dbPage) if err != nil { return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) @@ -302,25 +370,244 @@ func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (cli return page, nil } -func (repo *clientRepo) RetrieveAllByIDs(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) { - if (len(pm.IDs) == 0) && (pm.Domain == "") { - return clients.ClientsPage{ - Page: clients.Page{Total: pm.Total, Offset: pm.Offset, Limit: pm.Limit}, - }, nil - } +func (repo *clientRepo) userClientBaseQuery(domainID, userID string) string { + return fmt.Sprintf(` + WITH direct_clients AS ( + SELECT + c.id, + c.name, + c.domain_id, + c.parent_group_id, + c.identity, + c.secret, + c.tags, + c.metadata, + c.created_at, + c.updated_at, + c.updated_by, + c.status, + text2ltree('') as parent_group_path, + cr.id AS role_id, + cr."name" AS role_name, + array_agg(cra."action") AS actions, + 'direct' as access_type, + '' AS access_provider_id, + '' AS access_provider_role_id, + '' AS access_provider_role_name, + array[]::::text[] AS access_provider_role_actions + FROM + clients_role_members crm + JOIN + clients_role_actions cra ON cra.role_id = crm.role_id + JOIN + clients_roles cr ON cr.id = crm.role_id + JOIN + clients c ON c.id = cr.entity_id + WHERE + crm.member_id = '%s' + AND c.domain_id = '%s' + GROUP BY + cr.entity_id, crm.member_id, cr.id, cr."name", c.id + ), + direct_groups AS ( + SELECT + g.*, + gr.entity_id AS entity_id, + grm.member_id AS member_id, + gr.id AS role_id, + gr."name" AS role_name, + array_agg(gra."action") AS actions + FROM + groups_role_members grm + JOIN + groups_role_actions gra ON gra.role_id = grm.role_id + JOIN + groups_roles gr ON gr.id = grm.role_id + JOIN + "groups" g ON g.id = gr.entity_id + WHERE + grm.member_id = '%s' + AND g.domain_id = '%s' + GROUP BY + gr.entity_id, grm.member_id, gr.id, gr."name", g."path", g.id + ), + direct_groups_with_subgroup AS ( + SELECT + * + FROM direct_groups + WHERE EXISTS ( + SELECT 1 + FROM unnest(direct_groups.actions) AS action + WHERE action LIKE 'subgroup_%%' + ) + ), + indirect_child_groups AS ( + SELECT + DISTINCT indirect_child_groups.id as child_id, + indirect_child_groups.*, + dgws.id as access_provider_id, + dgws.role_id as access_provider_role_id, + dgws.role_name as access_provider_role_name, + dgws.actions as access_provider_role_actions + FROM + direct_groups_with_subgroup dgws + JOIN + groups indirect_child_groups ON indirect_child_groups.path <@ dgws.path + WHERE + indirect_child_groups.domain_id = '%s' + AND NOT EXISTS ( + SELECT 1 + FROM direct_groups_with_subgroup dgws + WHERE dgws.id = indirect_child_groups.id + ) + ), + final_groups AS ( + SELECT + id, + parent_id, + domain_id, + "name", + description, + metadata, + created_at, + updated_at, + updated_by, + status, + "path", + role_id, + role_name, + actions, + 'direct_group' AS access_type, + '' AS access_provider_id, + '' AS access_provider_role_id, + '' AS access_provider_role_name, + array[]::::text[] AS access_provider_role_actions + FROM + direct_groups + UNION + SELECT + id, + parent_id, + domain_id, + "name", + description, + metadata, + created_at, + updated_at, + updated_by, + status, + "path", + '' AS role_id, + '' AS role_name, + array[]::::text[] AS actions, + 'indirect_group' AS access_type, + access_provider_id, + access_provider_role_id, + access_provider_role_name, + access_provider_role_actions + FROM + indirect_child_groups + ), + group_direct_clients AS ( + SELECT + c.id, + c.name, + c.domain_id, + c.parent_group_id, + c.identity, + c.secret, + c.tags, + c.metadata, + c.created_at, + c.updated_at, + c.updated_by, + c.status, + g.path AS parent_group_path, + g.role_id, + g.role_name, + g.actions, + g.access_type, + g.access_provider_id, + g.access_provider_role_id, + g.access_provider_role_name, + g.access_provider_role_actions + FROM + final_groups g + JOIN + clients c ON c.parent_group_id = g.id + WHERE + c.id NOT IN (SELECT id FROM direct_clients) + UNION + SELECT + dc.id, + dc.name, + dc.domain_id, + dc.parent_group_id, + dc.identity, + dc.secret, + dc.tags, + dc.metadata, + dc.created_at, + dc.updated_at, + dc.updated_by, + dc.status, + dc.parent_group_path, + dc.role_id, + dc.role_name, + dc.actions, + dc.access_type, + dc.access_provider_id, + dc.access_provider_role_id, + dc.access_provider_role_name, + dc.access_provider_role_actions + FROM + direct_clients AS dc + ), + final_clients AS ( + SELECT + gdc.id, + gdc.name, + gdc.domain_id, + gdc.parent_group_id, + gdc.identity, + gdc.secret, + gdc.tags, + gdc.metadata, + gdc.created_at, + gdc.updated_at, + gdc.updated_by, + gdc.status, + gdc.parent_group_path, + gdc.role_id, + gdc.role_name, + gdc.actions, + gdc.access_type, + gdc.access_provider_id, + gdc.access_provider_role_id, + gdc.access_provider_role_name, + gdc.access_provider_role_actions + FROM + group_direct_clients AS gdc + ) + `, userID, domainID, userID, domainID, domainID) +} + +func (repo *clientRepo) SearchClients(ctx context.Context, pm clients.Page) (clients.ClientsPage, error) { query, err := PageQuery(pm) if err != nil { return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) } + + tq := query query = applyOrdering(query, pm) - q := fmt.Sprintf(`SELECT c.id, c.name, c.tags, c.identity, c.metadata, COALESCE(c.domain_id, '') AS domain_id, COALESCE(parent_group_id, '') AS parent_group_id, c.status, - c.created_at, c.updated_at, COALESCE(c.updated_by, '') AS updated_by FROM clients c %s ORDER BY c.created_at LIMIT :limit OFFSET :offset;`, query) + q := fmt.Sprintf(`SELECT c.id, c.name, c.created_at, c.updated_at FROM clients c %s LIMIT :limit OFFSET :offset;`, query) dbPage, err := ToDBClientsPage(pm) if err != nil { return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) } + rows, err := repo.DB.NamedQueryContext(ctx, q, dbPage) if err != nil { return clients.ClientsPage{}, errors.Wrap(repoerr.ErrFailedToRetrieveAllGroups, err) @@ -341,8 +628,8 @@ func (repo *clientRepo) RetrieveAllByIDs(ctx context.Context, pm clients.Page) ( items = append(items, c) } - cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, query) + cq := fmt.Sprintf(`SELECT COUNT(*) FROM clients c %s;`, tq) total, err := postgres.Total(ctx, repo.DB, cq, dbPage) if err != nil { return clients.ClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) @@ -402,18 +689,27 @@ func (repo *clientRepo) Delete(ctx context.Context, clientIDs ...string) error { } type DBClient struct { - ID string `db:"id"` - Name string `db:"name,omitempty"` - Tags pgtype.TextArray `db:"tags,omitempty"` - Identity string `db:"identity"` - Domain string `db:"domain_id"` - ParentGroup sql.NullString `db:"parent_group_id,omitempty"` - Secret string `db:"secret"` - Metadata []byte `db:"metadata,omitempty"` - CreatedAt time.Time `db:"created_at,omitempty"` - UpdatedAt sql.NullTime `db:"updated_at,omitempty"` - UpdatedBy *string `db:"updated_by,omitempty"` - Status clients.Status `db:"status,omitempty"` + ID string `db:"id"` + Name string `db:"name,omitempty"` + Tags pgtype.TextArray `db:"tags,omitempty"` + Identity string `db:"identity"` + Domain string `db:"domain_id"` + ParentGroup sql.NullString `db:"parent_group_id,omitempty"` + Secret string `db:"secret"` + Metadata []byte `db:"metadata,omitempty"` + CreatedAt time.Time `db:"created_at,omitempty"` + UpdatedAt sql.NullTime `db:"updated_at,omitempty"` + UpdatedBy *string `db:"updated_by,omitempty"` + Status clients.Status `db:"status,omitempty"` + ParentGroupPath string `db:"parent_group_path,omitempty"` + RoleID string `db:"role_id,omitempty"` + RoleName string `db:"role_name,omitempty"` + Actions pq.StringArray `db:"actions,omitempty"` + AccessType string `db:"access_type,omitempty"` + AccessProviderId string `db:"access_provider_id,omitempty"` + AccessProviderRoleId string `db:"access_provider_role_id,omitempty"` + AccessProviderRoleName string `db:"access_provider_role_name,omitempty"` + AccessProviderRoleActions pq.StringArray `db:"access_provider_role_actions,omitempty"` } func ToDBClient(c clients.Client) (DBClient, error) { @@ -484,11 +780,19 @@ func ToClient(t DBClient) (clients.Client, error) { Identity: t.Identity, Secret: t.Secret, }, - Metadata: metadata, - CreatedAt: t.CreatedAt, - UpdatedAt: updatedAt, - UpdatedBy: updatedBy, - Status: t.Status, + Metadata: metadata, + CreatedAt: t.CreatedAt, + UpdatedAt: updatedAt, + UpdatedBy: updatedBy, + Status: t.Status, + RoleID: t.RoleID, + RoleName: t.RoleName, + Actions: t.Actions, + AccessType: t.AccessType, + AccessProviderId: t.AccessProviderId, + AccessProviderRoleId: t.AccessProviderRoleId, + AccessProviderRoleName: t.AccessProviderRoleName, + AccessProviderRoleActions: t.AccessProviderRoleActions, } return cli, nil } @@ -499,31 +803,42 @@ func ToDBClientsPage(pm clients.Page) (dbClientsPage, error) { return dbClientsPage{}, errors.Wrap(repoerr.ErrViewEntity, err) } return dbClientsPage{ - Name: pm.Name, - Identity: pm.Identity, - Id: pm.Id, - Metadata: data, - Domain: pm.Domain, - Total: pm.Total, - Offset: pm.Offset, - Limit: pm.Limit, - Status: pm.Status, - Tag: pm.Tag, + Offset: pm.Offset, + Limit: pm.Limit, + Name: pm.Name, + Identity: pm.Identity, + Id: pm.Id, + Metadata: data, + Domain: pm.Domain, + Status: pm.Status, + Tag: pm.Tag, + GroupID: pm.Group, + ChannelID: pm.Channel, + RoleName: pm.RoleName, + ConnType: pm.ConnectionType, + RoleID: pm.RoleID, + Actions: pm.Actions, + AccessType: pm.AccessType, }, nil } type dbClientsPage struct { - Total uint64 `db:"total"` - Limit uint64 `db:"limit"` - Offset uint64 `db:"offset"` - Name string `db:"name"` - Id string `db:"id"` - Domain string `db:"domain_id"` - Identity string `db:"identity"` - Metadata []byte `db:"metadata"` - Tag string `db:"tag"` - Status clients.Status `db:"status"` - GroupID string `db:"group_id"` + Limit uint64 `db:"limit"` + Offset uint64 `db:"offset"` + Name string `db:"name"` + Id string `db:"id"` + Domain string `db:"domain_id"` + Identity string `db:"identity"` + Metadata []byte `db:"metadata"` + Tag string `db:"tag"` + Status clients.Status `db:"status"` + GroupID string `db:"group_id"` + ChannelID string `db:"channel_id"` + ConnType string `db:"type"` + RoleName string `db:"role_name"` + RoleID string `db:"role_id"` + Actions pq.StringArray `db:"actions"` + AccessType string `db:"access_type"` } func PageQuery(pm clients.Page) (string, error) { @@ -534,36 +849,57 @@ func PageQuery(pm clients.Page) (string, error) { var query []string if pm.Name != "" { - query = append(query, "name ILIKE '%' || :name || '%'") + query = append(query, "c.name ILIKE '%' || :name || '%'") } if pm.Identity != "" { - query = append(query, "identity ILIKE '%' || :identity || '%'") + query = append(query, "c.identity ILIKE '%' || :identity || '%'") } if pm.Id != "" { - query = append(query, "id ILIKE '%' || :id || '%'") + query = append(query, "c.id ILIKE '%' || :id || '%'") } if pm.Tag != "" { query = append(query, "EXISTS (SELECT 1 FROM unnest(tags) AS tag WHERE tag ILIKE '%' || :tag || '%')") } - // If there are search params presents, use search and ignore other options. - // Always combine role with search params, so len(query) > 1. - if len(query) > 1 { - return fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")), nil - } if mq != "" { query = append(query, mq) } if len(pm.IDs) != 0 { - query = append(query, fmt.Sprintf("id IN ('%s')", strings.Join(pm.IDs, "','"))) + query = append(query, fmt.Sprintf("c.id IN ('%s')", strings.Join(pm.IDs, "','"))) } + if pm.Status != clients.AllStatus { query = append(query, "c.status = :status") } if pm.Domain != "" { query = append(query, "c.domain_id = :domain_id") } + if pm.Group != "" { + query = append(query, "c.parent_group_path @> (SELECT path from groups where id = :group_id) ") + } + if pm.Channel != "" { + query = append(query, "conn.channel_id = :channel_id ") + if pm.ConnectionType != "" { + query = append(query, "conn.type = :conn_type ") + } + } + if pm.AccessType != "" { + query = append(query, "c.access_type = :access_type") + } + if pm.RoleID != "" { + query = append(query, "c.role_id = :role_id") + } + if pm.RoleName != "" { + query = append(query, "c.role_name = :role_name") + } + if len(pm.Actions) != 0 { + query = append(query, "c.actions @> :actions") + } + if len(pm.Metadata) > 0 { + query = append(query, "c.metadata @> :metadata") + } + var emq string if len(query) > 0 { emq = fmt.Sprintf("WHERE %s", strings.Join(query, " AND ")) diff --git a/clients/postgres/clients_test.go b/clients/postgres/clients_test.go index ea37a3d9d7..8a6e46d0d5 100644 --- a/clients/postgres/clients_test.go +++ b/clients/postgres/clients_test.go @@ -1626,274 +1626,7 @@ func TestSearchClients(t *testing.T) { } } -func TestRetrieveAllByIDs(t *testing.T) { - t.Cleanup(func() { - _, err := db.Exec("DELETE FROM clients") - require.Nil(t, err, fmt.Sprintf("clean clients unexpected error: %s", err)) - }) - - repo := postgres.NewRepository(database) - - num := 200 - - var items []clients.Client - for i := 0; i < num; i++ { - name := namegen.Generate() - client := clients.Client{ - ID: testsutil.GenerateUUID(t), - Domain: testsutil.GenerateUUID(t), - Name: name, - Credentials: clients.Credentials{ - Identity: name + emailSuffix, - Secret: testsutil.GenerateUUID(t), - }, - Tags: namegen.GenerateMultiple(5), - Metadata: map[string]interface{}{"name": name}, - CreatedAt: time.Now().UTC().Truncate(time.Microsecond), - Status: clients.EnabledStatus, - } - _, err := repo.Save(context.Background(), client) - require.Nil(t, err, fmt.Sprintf("add new client: expected nil got %s\n", err)) - items = append(items, client) - } - - page, err := repo.RetrieveAll(context.Background(), clients.Page{Offset: 0, Limit: uint64(num)}) - require.Nil(t, err, fmt.Sprintf("retrieve all clients unexpected error: %s", err)) - assert.Equal(t, uint64(num), page.Total) - - cases := []struct { - desc string - page clients.Page - response clients.ClientsPage - err error - }{ - { - desc: "successfully", - page: clients.Page{ - Offset: 0, - Limit: 10, - IDs: getIDs(items[0:3]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 3, - Offset: 0, - Limit: 10, - }, - Clients: items[0:3], - }, - err: nil, - }, - { - desc: "with empty ids", - page: clients.Page{ - Offset: 0, - Limit: 10, - IDs: []string{}, - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Offset: 0, - Limit: 10, - }, - Clients: []clients.Client(nil), - }, - err: nil, - }, - { - desc: "with empty ids but with domain id", - page: clients.Page{ - Offset: 0, - Limit: 10, - Domain: items[0].Domain, - IDs: []string{}, - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 1, - Offset: 0, - Limit: 10, - }, - Clients: []clients.Client{items[0]}, - }, - err: nil, - }, - { - desc: "with offset only", - page: clients.Page{ - Offset: 10, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 20, - Offset: 10, - Limit: 0, - }, - Clients: []clients.Client(nil), - }, - err: nil, - }, - { - desc: "with limit only", - page: clients.Page{ - Limit: 10, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 20, - Offset: 0, - Limit: 10, - }, - Clients: items[0:10], - }, - err: nil, - }, - { - desc: "with offset out of range", - page: clients.Page{ - Offset: 1000, - Limit: 50, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 20, - Offset: 1000, - Limit: 50, - }, - Clients: []clients.Client(nil), - }, - err: nil, - }, - { - desc: "with offset and limit out of range", - page: clients.Page{ - Offset: 15, - Limit: 10, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 20, - Offset: 15, - Limit: 10, - }, - Clients: items[15:20], - }, - err: nil, - }, - { - desc: "with limit out of range", - page: clients.Page{ - Offset: 0, - Limit: 1000, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 20, - Offset: 0, - Limit: 1000, - }, - Clients: items[:20], - }, - err: nil, - }, - { - desc: "with name", - page: clients.Page{ - Offset: 0, - Limit: 10, - Name: items[0].Name, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 1, - Offset: 0, - Limit: 10, - }, - Clients: []clients.Client{items[0]}, - }, - err: nil, - }, - { - desc: "with domain id", - page: clients.Page{ - Offset: 0, - Limit: 10, - Domain: items[0].Domain, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 1, - Offset: 0, - Limit: 10, - }, - Clients: []clients.Client{items[0]}, - }, - err: nil, - }, - { - desc: "with metadata", - page: clients.Page{ - Offset: 0, - Limit: 10, - Metadata: items[0].Metadata, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 1, - Offset: 0, - Limit: 10, - }, - Clients: []clients.Client{items[0]}, - }, - err: nil, - }, - { - desc: "with invalid metadata", - page: clients.Page{ - Offset: 0, - Limit: 10, - Metadata: map[string]interface{}{ - "key": make(chan int), - }, - IDs: getIDs(items[0:20]), - }, - response: clients.ClientsPage{ - Page: clients.Page{ - Total: 0, - Offset: 0, - Limit: 10, - }, - Clients: []clients.Client(nil), - }, - err: errors.ErrMalformedEntity, - }, - } - - for _, c := range cases { - switch response, err := repo.RetrieveAllByIDs(context.Background(), c.page); { - case err == nil: - assert.Nil(t, err, fmt.Sprintf("%s: expected %s got %s\n", c.desc, c.err, err)) - assert.Equal(t, c.response.Total, response.Total) - assert.Equal(t, c.response.Limit, response.Limit) - assert.Equal(t, c.response.Offset, response.Offset) - expected := stripClientDetails(c.response.Clients) - got := stripClientDetails(response.Clients) - assert.ElementsMatch(t, expected, got) - default: - assert.True(t, errors.Contains(err, c.err), fmt.Sprintf("expected %s to contain %s\n", err, c.err)) - } - } -} - -func TestRetrievByIDs(t *testing.T) { +func TestRetrieveByIDs(t *testing.T) { t.Cleanup(func() { _, err := db.Exec("DELETE FROM clients") require.Nil(t, err, fmt.Sprintf("clean clients unexpected error: %s", err)) diff --git a/clients/postgres/init.go b/clients/postgres/init.go index 11b0110ef1..0a3d347c66 100644 --- a/clients/postgres/init.go +++ b/clients/postgres/init.go @@ -4,6 +4,7 @@ package postgres import ( + gpostgres "github.com/absmach/supermq/groups/postgres" "github.com/absmach/supermq/pkg/errors" repoerr "github.com/absmach/supermq/pkg/errors/repository" rolesPostgres "github.com/absmach/supermq/pkg/roles/repo/postgres" @@ -60,5 +61,12 @@ func Migration() (*migrate.MemoryMigrationSource, error) { clientsMigration.Migrations = append(clientsMigration.Migrations, clientsRolesMigration.Migrations...) + groupsMigration, err := gpostgres.Migration() + if err != nil { + return &migrate.MemoryMigrationSource{}, err + } + + clientsMigration.Migrations = append(clientsMigration.Migrations, groupsMigration.Migrations...) + return clientsMigration, nil } diff --git a/clients/roleoperations.go b/clients/roleoperations.go index 88a6d16ed3..9636c12213 100644 --- a/clients/roleoperations.go +++ b/clients/roleoperations.go @@ -144,7 +144,7 @@ func NewRolesOperationPermissionMap() map[svcutil.Operation]svcutil.Permission { const ( // External Permission for domains. domainCreateClientPermission = "client_create_permission" - domainListClientsPermission = "list_clients_permission" + domainListClientsPermission = "client_read_permission" // External Permission for groups. groupSetChildClientPermission = "client_create_permission" groupRemoveChildClientPermission = "client_create_permission" diff --git a/clients/service.go b/clients/service.go index 471e30a680..e563a886fd 100644 --- a/clients/service.go +++ b/clients/service.go @@ -12,13 +12,11 @@ import ( grpcCommonV1 "github.com/absmach/supermq/api/grpc/common/v1" grpcGroupsV1 "github.com/absmach/supermq/api/grpc/groups/v1" apiutil "github.com/absmach/supermq/api/http/util" - smqauth "github.com/absmach/supermq/auth" "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/policies" "github.com/absmach/supermq/pkg/roles" - "golang.org/x/sync/errgroup" ) var ( @@ -131,113 +129,29 @@ func (svc service) View(ctx context.Context, session authn.Session, id string) ( return client, nil } -func (svc service) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm Page) (ClientsPage, error) { - var ids []string - var err error - switch { - case (reqUserID != "" && reqUserID != session.UserID): - rtids, err := svc.listClientIDs(ctx, smqauth.EncodeDomainUserID(session.DomainID, reqUserID), pm.Permission) +func (svc service) ListClients(ctx context.Context, session authn.Session, pm Page) (ClientsPage, error) { + switch session.SuperAdmin { + case true: + cp, err := svc.repo.RetrieveAll(ctx, pm) if err != nil { - return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err) - } - ids, err = svc.filterAllowedClientIDs(ctx, session.DomainUserID, pm.Permission, rtids) - if err != nil { - return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err) + return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err) } + return cp, nil default: - switch session.SuperAdmin { - case true: - pm.Domain = session.DomainID - default: - ids, err = svc.listClientIDs(ctx, session.DomainUserID, pm.Permission) - if err != nil { - return ClientsPage{}, errors.Wrap(svcerr.ErrNotFound, err) - } - } - } - - if len(ids) == 0 && pm.Domain == "" { - return ClientsPage{}, nil - } - pm.IDs = ids - tp, err := svc.repo.SearchClients(ctx, pm) - if err != nil { - return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err) - } - - if pm.ListPerms && len(tp.Clients) > 0 { - g, ctx := errgroup.WithContext(ctx) - - for i := range tp.Clients { - // Copying loop variable "i" to avoid "loop variable captured by func literal" - iter := i - g.Go(func() error { - return svc.retrievePermissions(ctx, session.DomainUserID, &tp.Clients[iter]) - }) - } - - if err := g.Wait(); err != nil { - return ClientsPage{}, err + cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, session.UserID, pm) + if err != nil { + return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err) } + return cp, nil } - return tp, nil -} - -// Experimental functions used for async calling of svc.listUserClientPermission. This might be helpful during listing of large number of entities. -func (svc service) retrievePermissions(ctx context.Context, userID string, client *Client) error { - permissions, err := svc.listUserClientPermission(ctx, userID, client.ID) - if err != nil { - return err - } - client.Permissions = permissions - return nil -} - -func (svc service) listUserClientPermission(ctx context.Context, userID, clientID string) ([]string, error) { - permissions, err := svc.policy.ListPermissions(ctx, policies.Policy{ - SubjectType: policies.UserType, - Subject: userID, - Object: clientID, - ObjectType: policies.ClientType, - }, []string{}) - if err != nil { - return []string{}, errors.Wrap(svcerr.ErrAuthorization, err) - } - return permissions, nil } -func (svc service) listClientIDs(ctx context.Context, userID, permission string) ([]string, error) { - tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{ - SubjectType: policies.UserType, - Subject: userID, - Permission: permission, - ObjectType: policies.ClientType, - }) +func (svc service) ListUserClients(ctx context.Context, session authn.Session, userID string, pm Page) (ClientsPage, error) { + cp, err := svc.repo.RetrieveUserClients(ctx, session.DomainID, userID, pm) if err != nil { - return nil, errors.Wrap(svcerr.ErrNotFound, err) - } - return tids.Policies, nil -} - -func (svc service) filterAllowedClientIDs(ctx context.Context, userID, permission string, clientIDs []string) ([]string, error) { - var ids []string - tids, err := svc.policy.ListAllObjects(ctx, policies.Policy{ - SubjectType: policies.UserType, - Subject: userID, - Permission: permission, - ObjectType: policies.ClientType, - }) - if err != nil { - return nil, errors.Wrap(svcerr.ErrNotFound, err) - } - for _, clientID := range clientIDs { - for _, tid := range tids.Policies { - if clientID == tid { - ids = append(ids, clientID) - } - } + return ClientsPage{}, errors.Wrap(svcerr.ErrViewEntity, err) } - return ids, nil + return cp, nil } func (svc service) Update(ctx context.Context, session authn.Session, cli Client) (Client, error) { diff --git a/clients/service_test.go b/clients/service_test.go index 9683bdc857..e3e51b6f4f 100644 --- a/clients/service_test.go +++ b/clients/service_test.go @@ -351,7 +351,6 @@ func TestListClients(t *testing.T) { adminID := testsutil.GenerateUUID(t) domainID := testsutil.GenerateUUID(t) nonAdminID := testsutil.GenerateUUID(t) - client.Permissions = []string{"read", "edit"} cases := []struct { desc string @@ -375,9 +374,8 @@ func TestListClients(t *testing.T) { session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false}, id: nonAdminID, page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, + Offset: 0, + Limit: 100, }, listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}}, retrieveAllResponse: clients.ClientsPage{ @@ -388,7 +386,6 @@ func TestListClients(t *testing.T) { }, Clients: []clients.Client{client, client}, }, - listPermissionsResponse: client.Permissions, response: clients.ClientsPage{ Page: clients.Page{ Total: 2, @@ -405,9 +402,8 @@ func TestListClients(t *testing.T) { session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false}, id: nonAdminID, page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, + Offset: 0, + Limit: 100, }, listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}}, retrieveAllResponse: clients.ClientsPage{}, @@ -415,39 +411,14 @@ func TestListClients(t *testing.T) { retrieveAllErr: repoerr.ErrNotFound, err: svcerr.ErrNotFound, }, - { - desc: "list all clients as non admin with failed to list permissions", - userKind: "non-admin", - session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false}, - id: nonAdminID, - page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, - }, - listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}}, - retrieveAllResponse: clients.ClientsPage{ - Page: clients.Page{ - Total: 2, - Offset: 0, - Limit: 100, - }, - Clients: []clients.Client{client, client}, - }, - listPermissionsResponse: []string{}, - response: clients.ClientsPage{}, - listPermissionsErr: svcerr.ErrNotFound, - err: svcerr.ErrNotFound, - }, { desc: "list all clients as non admin with failed super admin", userKind: "non-admin", session: smqauthn.Session{UserID: nonAdminID, DomainID: domainID, SuperAdmin: false}, id: nonAdminID, page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, + Offset: 0, + Limit: 100, }, response: clients.ClientsPage{}, listObjectsResponse: policysvc.PolicyPage{}, @@ -458,10 +429,10 @@ func TestListClients(t *testing.T) { userKind: "non-admin", id: nonAdminID, page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, + Offset: 0, + Limit: 100, }, + retrieveAllErr: repoerr.ErrNotFound, response: clients.ClientsPage{}, listObjectsResponse: policysvc.PolicyPage{}, listObjectsErr: svcerr.ErrNotFound, @@ -470,15 +441,13 @@ func TestListClients(t *testing.T) { } for _, tc := range cases { - listAllObjectsCall := pService.On("ListAllObjects", mock.Anything, mock.Anything).Return(tc.listObjectsResponse, tc.listObjectsErr) - retrieveAllCall := repo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) - listPermissionsCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr) - page, err := svc.ListClients(context.Background(), tc.session, tc.id, tc.page) + retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) + retrieveUserClientsCall := repo.On("RetrieveUserClients", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) + page, err := svc.ListClients(context.Background(), tc.session, tc.page) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) - listAllObjectsCall.Unset() retrieveAllCall.Unset() - listPermissionsCall.Unset() + retrieveUserClientsCall.Unset() } cases2 := []struct { @@ -503,10 +472,9 @@ func TestListClients(t *testing.T) { id: adminID, session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true}, page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, - Domain: domainID, + Offset: 0, + Limit: 100, + Domain: domainID, }, listObjectsResponse: policysvc.PolicyPage{Policies: []string{client.ID, client.ID}}, retrieveAllResponse: clients.ClientsPage{ @@ -517,7 +485,6 @@ func TestListClients(t *testing.T) { }, Clients: []clients.Client{client, client}, }, - listPermissionsResponse: client.Permissions, response: clients.ClientsPage{ Page: clients.Page{ Total: 2, @@ -534,50 +501,24 @@ func TestListClients(t *testing.T) { id: adminID, session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true}, page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, - Domain: domainID, + Offset: 0, + Limit: 100, + Domain: domainID, }, listObjectsResponse: policysvc.PolicyPage{}, retrieveAllResponse: clients.ClientsPage{}, retrieveAllErr: repoerr.ErrNotFound, err: svcerr.ErrNotFound, }, - { - desc: "list all clients as admin with failed to list permissions", - userKind: "admin", - id: adminID, - session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true}, - page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, - Domain: domainID, - }, - listObjectsResponse: policysvc.PolicyPage{}, - retrieveAllResponse: clients.ClientsPage{ - Page: clients.Page{ - Total: 2, - Offset: 0, - Limit: 100, - }, - Clients: []clients.Client{client, client}, - }, - listPermissionsResponse: []string{}, - listPermissionsErr: svcerr.ErrNotFound, - err: svcerr.ErrNotFound, - }, { desc: "list all clients as admin with failed to list clients", userKind: "admin", id: adminID, session: smqauthn.Session{UserID: adminID, DomainID: domainID, SuperAdmin: true}, page: clients.Page{ - Offset: 0, - Limit: 100, - ListPerms: true, - Domain: domainID, + Offset: 0, + Limit: 100, + Domain: domainID, }, retrieveAllResponse: clients.ClientsPage{}, retrieveAllErr: repoerr.ErrNotFound, @@ -586,27 +527,11 @@ func TestListClients(t *testing.T) { } for _, tc := range cases2 { - listAllObjectsCall := pService.On("ListAllObjects", context.Background(), policysvc.Policy{ - SubjectType: policysvc.UserType, - Subject: tc.session.DomainID + "_" + adminID, - Permission: "", - ObjectType: policysvc.ClientType, - }).Return(tc.listObjectsResponse, tc.listObjectsErr) - listAllObjectsCall2 := pService.On("ListAllObjects", context.Background(), policysvc.Policy{ - SubjectType: policysvc.UserType, - Subject: tc.session.UserID, - Permission: "", - ObjectType: policysvc.ClientType, - }).Return(tc.listObjectsResponse, tc.listObjectsErr) - retrieveAllCall := repo.On("SearchClients", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) - listPermissionsCall := pService.On("ListPermissions", mock.Anything, mock.Anything, mock.Anything).Return(tc.listPermissionsResponse, tc.listPermissionsErr) - page, err := svc.ListClients(context.Background(), tc.session, tc.id, tc.page) + retrieveAllCall := repo.On("RetrieveAll", mock.Anything, mock.Anything).Return(tc.retrieveAllResponse, tc.retrieveAllErr) + page, err := svc.ListClients(context.Background(), tc.session, tc.page) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) assert.Equal(t, tc.response, page, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) - listAllObjectsCall.Unset() - listAllObjectsCall2.Unset() retrieveAllCall.Unset() - listPermissionsCall.Unset() } } diff --git a/clients/status_test.go b/clients/status_test.go index 1e0c4b6c6f..4b60afa610 100644 --- a/clients/status_test.go +++ b/clients/status_test.go @@ -197,7 +197,7 @@ func TestStatusUnmarshalJSON(t *testing.T) { } } -func TestUserMarshalJSON(t *testing.T) { +func TestClientMarshalJSON(t *testing.T) { cases := []struct { desc string expected []byte diff --git a/clients/tracing/tracing.go b/clients/tracing/tracing.go index fd2b266220..df69f2891c 100644 --- a/clients/tracing/tracing.go +++ b/clients/tracing/tracing.go @@ -47,10 +47,16 @@ func (tm *tracingMiddleware) View(ctx context.Context, session authn.Session, id } // ListClients traces the "ListClients" operation of the wrapped clients.Service. -func (tm *tracingMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) { +func (tm *tracingMiddleware) ListClients(ctx context.Context, session authn.Session, pm clients.Page) (clients.ClientsPage, error) { ctx, span := tm.tracer.Start(ctx, "svc_list_clients") defer span.End() - return tm.svc.ListClients(ctx, session, reqUserID, pm) + return tm.svc.ListClients(ctx, session, pm) +} + +func (tm *tracingMiddleware) ListUserClients(ctx context.Context, session authn.Session, userID string, pm clients.Page) (clients.ClientsPage, error) { + ctx, span := tm.tracer.Start(ctx, "svc_list_clients") + defer span.End() + return tm.svc.ListUserClients(ctx, session, userID, pm) } // Update traces the "Update" operation of the wrapped clients.Service. diff --git a/cmd/channels/main.go b/cmd/channels/main.go index 3157343cc5..2afe764680 100644 --- a/cmd/channels/main.go +++ b/cmd/channels/main.go @@ -25,11 +25,13 @@ import ( "github.com/absmach/supermq/channels/postgres" pChannels "github.com/absmach/supermq/channels/private" "github.com/absmach/supermq/channels/tracing" + gpostgres "github.com/absmach/supermq/groups/postgres" smqlog "github.com/absmach/supermq/logger" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + gconsumer "github.com/absmach/supermq/pkg/groups/events/consumer" "github.com/absmach/supermq/pkg/grpcclient" jaegerclient "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/policies" @@ -74,6 +76,7 @@ type config struct { JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` + ESConsumerName string `env:"SMQ_CLIENTS_EVENT_CONSUMER" envDefault:"channels"` TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` @@ -224,6 +227,15 @@ func main() { return } + gdatabase := pg.NewDatabase(db, dbConfig, tracer) + grepo := gpostgres.New(gdatabase) + + if err := gconsumer.GroupsEventsSubscribe(ctx, grepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { + logger.Error(fmt.Sprintf("failed to create groups event store : %s", err)) + exitCode = 1 + return + } + grpcServerConfig := server.Config{Port: defSvcGRPCPort} if err := env.ParseWithOptions(&grpcServerConfig, env.Options{Prefix: envPrefixGRPC}); err != nil { logger.Error(fmt.Sprintf("failed to load %s gRPC server configuration : %s", svcName, err)) diff --git a/cmd/clients/main.go b/cmd/clients/main.go index 94b9c402d3..de3b089f39 100644 --- a/cmd/clients/main.go +++ b/cmd/clients/main.go @@ -27,12 +27,14 @@ import ( "github.com/absmach/supermq/clients/postgres" pClients "github.com/absmach/supermq/clients/private" "github.com/absmach/supermq/clients/tracing" + gpostgres "github.com/absmach/supermq/groups/postgres" redisclient "github.com/absmach/supermq/internal/clients/redis" smqlog "github.com/absmach/supermq/logger" authsvcAuthn "github.com/absmach/supermq/pkg/authn/authsvc" smqauthz "github.com/absmach/supermq/pkg/authz" authsvcAuthz "github.com/absmach/supermq/pkg/authz/authsvc" domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" + gconsumer "github.com/absmach/supermq/pkg/groups/events/consumer" "github.com/absmach/supermq/pkg/grpcclient" jaegerclient "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/policies" @@ -82,6 +84,7 @@ type config struct { JaegerURL url.URL `env:"SMQ_JAEGER_URL" envDefault:"http://localhost:4318/v1/traces"` SendTelemetry bool `env:"SMQ_SEND_TELEMETRY" envDefault:"true"` ESURL string `env:"SMQ_ES_URL" envDefault:"nats://localhost:4222"` + ESConsumerName string `env:"SMQ_CLIENTS_EVENT_CONSUMER" envDefault:"clients"` TraceRatio float64 `env:"SMQ_JAEGER_TRACE_RATIO" envDefault:"1.0"` SpicedbHost string `env:"SMQ_SPICEDB_HOST" envDefault:"localhost"` SpicedbPort string `env:"SMQ_SPICEDB_PORT" envDefault:"50051"` @@ -241,6 +244,15 @@ func main() { return } + gdatabase := pg.NewDatabase(db, dbConfig, tracer) + grepo := gpostgres.New(gdatabase) + + if err := gconsumer.GroupsEventsSubscribe(ctx, grepo, cfg.ESURL, cfg.ESConsumerName, logger); err != nil { + logger.Error(fmt.Sprintf("failed to create groups event store : %s", err)) + exitCode = 1 + return + } + httpServerConfig := server.Config{Port: defSvcHTTPPort} if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { logger.Error(fmt.Sprintf("failed to load %s HTTP server configuration : %s", svcName, err)) diff --git a/groups/events/events.go b/groups/events/events.go index dd8e565468..8dc117e3a2 100644 --- a/groups/events/events.go +++ b/groups/events/events.go @@ -357,13 +357,13 @@ type addChildrenGroupsEvent struct { func (acge addChildrenGroupsEvent) Encode() (map[string]interface{}, error) { return map[string]interface{}{ - "operation": groupAddChildrenGroups, - "id": acge.id, - "childre_ids": acge.childrenIDs, - "domain": acge.DomainID, - "user_id": acge.UserID, - "token_type": acge.Type.String(), - "super_admin": acge.SuperAdmin, + "operation": groupAddChildrenGroups, + "id": acge.id, + "children_ids": acge.childrenIDs, + "domain": acge.DomainID, + "user_id": acge.UserID, + "token_type": acge.Type.String(), + "super_admin": acge.SuperAdmin, }, nil } diff --git a/groups/groups.go b/groups/groups.go index b9bf07046f..3a442f2335 100644 --- a/groups/groups.go +++ b/groups/groups.go @@ -142,9 +142,10 @@ type Service interface { // ViewGroup retrieves data about the group identified by ID. ViewGroup(ctx context.Context, session authn.Session, id string) (Group, error) - // ListGroups retrieves + // ListGroups retrieves groups for given filters. ListGroups(ctx context.Context, session authn.Session, pm PageMeta) (Page, error) + // ListGroups retrieves user accessible groups for given filters. ListUserGroups(ctx context.Context, session authn.Session, userID string, pm PageMeta) (Page, error) // EnableGroup logically enables the group identified with the provided ID. diff --git a/groups/middleware/authorization.go b/groups/middleware/authorization.go index 2f9c01687f..668baa01e6 100644 --- a/groups/middleware/authorization.go +++ b/groups/middleware/authorization.go @@ -73,6 +73,7 @@ func AuthorizationMiddleware(entityType string, svc groups.Service, repo groups. } return &authorizationMiddleware{ svc: svc, + repo: repo, authz: authz, opp: opp, extOpp: extOpp, diff --git a/pkg/errors/repository/types.go b/pkg/errors/repository/types.go index d468519e16..2f068e1e1f 100644 --- a/pkg/errors/repository/types.go +++ b/pkg/errors/repository/types.go @@ -34,7 +34,8 @@ var ( // ErrFailedToRetrieveAllGroups failed to retrieve groups. ErrFailedToRetrieveAllGroups = errors.New("failed to retrieve all groups") - ErrRoleMigration = errors.New("role migration initialization failed") + // ErrRoleMigration failed to apply role migrations. + ErrRoleMigration = errors.New("failed to apply role migration") // ErrMissingNames indicates missing first and last names. ErrMissingNames = errors.New("missing first or last name") diff --git a/pkg/events/events.go b/pkg/events/events.go index 65845a785c..e695099733 100644 --- a/pkg/events/events.go +++ b/pkg/events/events.go @@ -43,6 +43,7 @@ type SubscriberConfig struct { Consumer string Stream string Handler EventHandler + Ordered bool } // Subscriber specifies event subscription API. diff --git a/pkg/events/nats/subscriber.go b/pkg/events/nats/subscriber.go index 95e78bc57b..2ebb9849f9 100644 --- a/pkg/events/nats/subscriber.go +++ b/pkg/events/nats/subscriber.go @@ -93,6 +93,7 @@ func (es *subEventStore) Subscribe(ctx context.Context, cfg events.SubscriberCon logger: es.logger, }, DeliveryPolicy: messaging.DeliverNewPolicy, + Ordered: cfg.Ordered, } return es.pubsub.Subscribe(ctx, subCfg) @@ -126,8 +127,9 @@ func (eh *eventHandler) Handle(msg *messaging.Message) error { return err } - if err := eh.handler.Handle(eh.ctx, event); err != nil { - eh.logger.Warn(fmt.Sprintf("failed to handle nats event: %s", err)) + err := eh.handler.Handle(eh.ctx, event) + if err != nil { + return fmt.Errorf("failed to handle nats event: %s", err) } return nil diff --git a/pkg/groups/events/consumer/decode.go b/pkg/groups/events/consumer/decode.go new file mode 100644 index 0000000000..84371d5f95 --- /dev/null +++ b/pkg/groups/events/consumer/decode.go @@ -0,0 +1,255 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "time" + + "github.com/absmach/supermq/groups" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/roles" + rconsumer "github.com/absmach/supermq/pkg/roles/rolemanager/events/consumer" +) + +var ( + errDecodeCreateGroupEvent = errors.New("failed to decode group create event") + errDecodeUpdateGroupEvent = errors.New("failed to decode group update event") + errDecodeChangeStatusGroupEvent = errors.New("failed to decode group change status event") + errDecodeRemoveGroupEvent = errors.New("failed to decode group remove event") + errDecodeAddParentGroupEvent = errors.New("failed to decode group add parent event") + errDecodeRemoveParentGroupEvent = errors.New("failed to decode group remove parent event") + errDecodeAddChildrenGroupsEvent = errors.New("failed to decode group add children groups event") + errDecodeRemoveChildrenGroupsEvent = errors.New("failed to decode group remove children groups event") + + errID = errors.New("missing or invalid 'id'") + errName = errors.New("missing or invalid 'name'") + errDomain = errors.New("missing or invalid 'domain'") + errParent = errors.New("missing or invalid 'parent'") + errChildrenIDs = errors.New("missing or invalid 'children_ids'") + errStatus = errors.New("missing or invalid 'status'") + errConvertStatus = errors.New("failed to convert status") + errCreatedAt = errors.New("failed to parse 'created_at' time") + errUpdatedAt = errors.New("failed to parse 'updated_at' time") +) + +const ( + layout = "2006-01-02T15:04:05.999999Z" +) + +func ToGroups(data map[string]interface{}) (groups.Group, error) { + var g groups.Group + id, ok := data["id"].(string) + if !ok { + return groups.Group{}, errID + } + g.ID = id + + name, ok := data["name"].(string) + if !ok { + return groups.Group{}, errName + } + g.Name = name + + dom, ok := data["domain"].(string) + if !ok { + return groups.Group{}, errDomain + } + g.Domain = dom + + stat, ok := data["status"].(string) + if !ok { + return groups.Group{}, errStatus + } + st, err := groups.ToStatus(stat) + if err != nil { + return groups.Group{}, errors.Wrap(errConvertStatus, err) + } + g.Status = st + + cat, ok := data["created_at"].(string) + if !ok { + return groups.Group{}, errCreatedAt + } + ct, err := time.Parse(layout, cat) + if err != nil { + return groups.Group{}, errors.Wrap(errCreatedAt, err) + } + g.CreatedAt = ct + + // Following fields of groups are allowed to be empty. + + desc, ok := data["description"].(string) + if ok { + g.Description = desc + } + + parent, ok := data["parent"].(string) + if ok { + g.Parent = parent + } + + meta, ok := data["metadata"].(map[string]interface{}) + if ok { + g.Metadata = meta + } + + uby, ok := data["updated_by"].(string) + if ok { + g.UpdatedBy = uby + } + + uat, ok := data["updated_at"].(string) + if ok { + ut, err := time.Parse(layout, uat) + if err != nil { + return groups.Group{}, errors.Wrap(errUpdatedAt, err) + } + g.UpdatedAt = ut + } + + return g, nil +} + +func decodeCreateGroupEvent(data map[string]interface{}) (groups.Group, []roles.RoleProvision, error) { + g, err := ToGroups(data) + if err != nil { + return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, err) + } + irps, ok := data["roles_provisioned"].([]interface{}) + if !ok { + return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, errors.New("missing or invalid 'roles_provisioned'")) + } + rps, err := rconsumer.ToRoleProvisions(irps) + if err != nil { + return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDecodeCreateGroupEvent, err) + } + + return g, rps, nil +} + +func decodeUpdateGroupEvent(data map[string]interface{}) (groups.Group, error) { + g, err := ToGroups(data) + if err != nil { + return groups.Group{}, errors.Wrap(errDecodeUpdateGroupEvent, err) + } + return g, nil +} + +func ToGroupStatus(data map[string]interface{}) (groups.Group, error) { + var g groups.Group + id, ok := data["id"].(string) + if !ok { + return groups.Group{}, errID + } + g.ID = id + + stat, ok := data["status"].(string) + if !ok { + return groups.Group{}, errStatus + } + st, err := groups.ToStatus(stat) + if err != nil { + return groups.Group{}, errors.Wrap(errConvertStatus, err) + } + g.Status = st + + uat, ok := data["updated_at"].(string) + if ok { + ut, err := time.Parse(layout, uat) + if err != nil { + return groups.Group{}, errors.Wrap(errUpdatedAt, err) + } + g.UpdatedAt = ut + } + + uby, ok := data["updated_by"].(string) + if ok { + g.UpdatedBy = uby + } + + return g, nil +} + +func decodeChangeStatusGroupEvent(data map[string]interface{}) (groups.Group, error) { + g, err := ToGroupStatus(data) + if err != nil { + return groups.Group{}, errors.Wrap(errDecodeChangeStatusGroupEvent, err) + } + return g, nil +} + +func decodeRemoveGroupEvent(data map[string]interface{}) (groups.Group, error) { + var g groups.Group + id, ok := data["id"].(string) + if !ok { + return groups.Group{}, errors.Wrap(errDecodeRemoveGroupEvent, errID) + } + g.ID = id + + return g, nil +} + +func decodeAddParentGroupEvent(data map[string]interface{}) (id string, parent string, err error) { + id, ok := data["id"].(string) + if !ok { + return "", "", errors.Wrap(errAddParentGroupEvent, errID) + } + + parent, ok = data["parent_id"].(string) + if !ok { + return "", "", errors.Wrap(errDecodeAddParentGroupEvent, errParent) + } + + return id, parent, nil +} + +func decodeRemoveParentGroupEvent(data map[string]interface{}) (id string, err error) { + id, ok := data["id"].(string) + if !ok { + return "", errors.Wrap(errDecodeRemoveParentGroupEvent, errID) + } + + return id, nil +} + +func decodeAddChildrenGroupEvent(data map[string]interface{}) (id string, childrenIDs []string, err error) { + id, ok := data["id"].(string) + if !ok { + return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errID) + } + chIDs, ok := data["children_ids"].([]interface{}) + if !ok { + return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errChildrenIDs) + } + cids, err := rconsumer.ToStrings(chIDs) + if err != nil { + return "", []string{}, errors.Wrap(errDecodeAddChildrenGroupsEvent, errors.Wrap(errChildrenIDs, err)) + } + return id, cids, nil +} + +func decodeRemoveChildrenGroupEvent(data map[string]interface{}) (id string, childrenIDs []string, err error) { + id, ok := data["id"].(string) + if !ok { + return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errID) + } + chIDs, ok := data["children_ids"].([]interface{}) + if !ok { + return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errChildrenIDs) + } + cids, err := rconsumer.ToStrings(chIDs) + if err != nil { + return "", []string{}, errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errors.Wrap(errChildrenIDs, err)) + } + return id, cids, nil +} + +func decodeRemoveAllChildrenGroupEvent(data map[string]interface{}) (id string, err error) { + id, ok := data["id"].(string) + if !ok { + return "", errors.Wrap(errDecodeRemoveChildrenGroupsEvent, errID) + } + + return id, nil +} diff --git a/pkg/groups/events/consumer/doc.go b/pkg/groups/events/consumer/doc.go new file mode 100644 index 0000000000..f3fea76f1e --- /dev/null +++ b/pkg/groups/events/consumer/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package consumer contains events consumer for events +// published by Bootstrap service. +package consumer diff --git a/pkg/groups/events/consumer/streams.go b/pkg/groups/events/consumer/streams.go new file mode 100644 index 0000000000..feccb4121e --- /dev/null +++ b/pkg/groups/events/consumer/streams.go @@ -0,0 +1,253 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "context" + "fmt" + "log/slog" + + "github.com/absmach/supermq/groups" + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/events/store" + rconsumer "github.com/absmach/supermq/pkg/roles/rolemanager/events/consumer" +) + +const ( + stream = "events.supermq.groups" + + create = "group.create" + update = "group.update" + changeStatus = "group.change_status" + remove = "group.remove" + addParentGroup = "group.add_parent_group" + removeParentGroup = "group.remove_parent_group" + addChildrenGroups = "group.add_children_groups" + removeChildrenGroups = "group.remove_children_groups" + removeAllChildrenGroups = "group.remove_all_children_groups" + addRole = "group.role.add" + removeRole = "group.role.remove" + updateRole = "group.role.update" + addRoleActions = "group.role.actions.add" + removeRoleActions = "group.role.actions.remove" + removeAllRoleActions = "group.role.actions.remove_all" + addRoleMembers = "group.role.members.add" + removeRoleMembers = "group.role.members.remove" + removeRoleAllMembers = "group.role.members.remove_all" + removeMemberFromAllRoles = "group.role.members.remove_from_all_roles" +) + +var ( + errNoOperationKey = errors.New("operation key is not found in event message") + errCreateGroupEvent = errors.New("failed to consume group create event") + errUpdateGroupEvent = errors.New("failed to consume group update event") + errChangeStatusGroupEvent = errors.New("failed to consume group change status event") + errRemoveGroupEvent = errors.New("failed to consume group remove event") + errAddParentGroupEvent = errors.New("failed to consume group add parent group event") + errRemoveParentGroupEvent = errors.New("failed to consume group remove parent group event") + errAddChildrenGroupEvent = errors.New("failed to consume group add children groups event") + errRemoveChildrenGroupEvent = errors.New("failed to consume group remove children groups event") + errRemoveAllChildrenGroupEvent = errors.New("failed to consume group remove all children groups event") +) + +type eventHandler struct { + repo groups.Repository + rolesEventHandler rconsumer.EventHandler +} + +func GroupsEventsSubscribe(ctx context.Context, repo groups.Repository, esURL, esConsumerName string, logger *slog.Logger) error { + subscriber, err := store.NewSubscriber(ctx, esURL, logger) + if err != nil { + return err + } + + subConfig := events.SubscriberConfig{ + Stream: stream, + Consumer: esConsumerName, + Handler: NewEventHandler(repo), + Ordered: true, + } + return subscriber.Subscribe(ctx, subConfig) +} + +// NewEventHandler returns new event store handler. +func NewEventHandler(repo groups.Repository) events.EventHandler { + reh := rconsumer.NewEventHandler("group", repo) + return &eventHandler{ + repo: repo, + rolesEventHandler: reh, + } +} + +func (es *eventHandler) Handle(ctx context.Context, event events.Event) error { + msg, err := event.Encode() + if err != nil { + return err + } + + op, ok := msg["operation"] + + if !ok { + return errNoOperationKey + } + switch op { + case create: + return es.createGroupHandler(ctx, msg) + case update: + return es.updateGroupHandler(ctx, msg) + case changeStatus: + return es.changeStatusGroupHandler(ctx, msg) + case remove: + return es.removeGroupHandler(ctx, msg) + case addParentGroup: + return es.addParentGroupHandler(ctx, msg) + case removeParentGroup: + return es.removeParentGroupHandler(ctx, msg) + case addChildrenGroups: + return es.addChildrenGroupsHandler(ctx, msg) + case removeChildrenGroups: + return es.removeChildrenGroupsHandler(ctx, msg) + case removeAllChildrenGroups: + return es.removeAllChildrenGroupsHandler(ctx, msg) + case addRole: + return es.rolesEventHandler.AddEntityRoleHandler(ctx, msg) + case updateRole: + return es.rolesEventHandler.UpdateEntityRoleHandler(ctx, msg) + case removeRole: + return es.rolesEventHandler.RemoveEntityRoleHandler(ctx, msg) + case addRoleActions: + return es.rolesEventHandler.AddEntityRoleActionsHandler(ctx, msg) + case removeRoleActions: + return es.rolesEventHandler.RemoveEntityRoleActionsHandler(ctx, msg) + case removeAllRoleActions: + return es.rolesEventHandler.RemoveAllEntityRoleActionsHandler(ctx, msg) + case addRoleMembers: + return es.rolesEventHandler.AddEntityRoleMembersHandler(ctx, msg) + case removeRoleMembers: + return es.rolesEventHandler.RemoveEntityRoleMembersHandler(ctx, msg) + case removeRoleAllMembers: + return es.rolesEventHandler.RemoveAllEntityRoleMembersHandler(ctx, msg) + case removeMemberFromAllRoles: + return es.rolesEventHandler.RemoveMemberFromAllEntityHandler(ctx, msg) + } + return nil +} + +func (es *eventHandler) createGroupHandler(ctx context.Context, data map[string]interface{}) error { + g, rps, err := decodeCreateGroupEvent(data) + if err != nil { + return errors.Wrap(errCreateGroupEvent, err) + } + + if _, err := es.repo.Save(ctx, g); err != nil { + return errors.Wrap(errCreateGroupEvent, err) + } + if _, err := es.repo.AddRoles(ctx, rps); err != nil { + return errors.Wrap(errCreateGroupEvent, err) + } + + return nil +} + +func (es *eventHandler) updateGroupHandler(ctx context.Context, data map[string]interface{}) error { + g, err := decodeUpdateGroupEvent(data) + if err != nil { + return errors.Wrap(errUpdateGroupEvent, err) + } + + if _, err := es.repo.Update(ctx, g); err != nil { + return errors.Wrap(errUpdateGroupEvent, err) + } + + return nil +} + +func (es *eventHandler) changeStatusGroupHandler(ctx context.Context, data map[string]interface{}) error { + g, err := decodeChangeStatusGroupEvent(data) + if err != nil { + return errors.Wrap(errChangeStatusGroupEvent, err) + } + + if _, err := es.repo.ChangeStatus(ctx, g); err != nil { + return errors.Wrap(errChangeStatusGroupEvent, err) + } + + return nil +} + +func (es *eventHandler) removeGroupHandler(ctx context.Context, data map[string]interface{}) error { + g, err := decodeRemoveGroupEvent(data) + if err != nil { + return errors.Wrap(errRemoveGroupEvent, err) + } + + if err := es.repo.Delete(ctx, g.ID); err != nil { + return errors.Wrap(errRemoveGroupEvent, err) + } + return nil +} + +func (es *eventHandler) addParentGroupHandler(ctx context.Context, data map[string]interface{}) error { + id, parent, err := decodeAddParentGroupEvent(data) + if err != nil { + return errors.Wrap(errAddParentGroupEvent, err) + } + if err := es.repo.AssignParentGroup(ctx, parent, id); err != nil { + return errors.Wrap(errAddParentGroupEvent, err) + } + return nil +} + +func (es *eventHandler) removeParentGroupHandler(ctx context.Context, data map[string]interface{}) error { + id, err := decodeRemoveParentGroupEvent(data) + if err != nil { + return errors.Wrap(errRemoveParentGroupEvent, err) + } + g, err := es.repo.RetrieveByID(ctx, id) + if err != nil { + return errors.Wrap(errRemoveParentGroupEvent, err) + } + fmt.Println(g, g.Parent, g.ID) + if err := es.repo.UnassignParentGroup(ctx, g.Parent, id); err != nil { + return errors.Wrap(errRemoveParentGroupEvent, err) + } + return nil +} + +func (es *eventHandler) addChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error { + id, cids, err := decodeAddChildrenGroupEvent(data) + if err != nil { + return errors.Wrap(errAddChildrenGroupEvent, err) + } + + if err := es.repo.AssignParentGroup(ctx, id, cids...); err != nil { + return errors.Wrap(errAddChildrenGroupEvent, err) + } + return nil +} + +func (es *eventHandler) removeChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error { + id, cids, err := decodeRemoveChildrenGroupEvent(data) + if err != nil { + return errors.Wrap(errRemoveChildrenGroupEvent, err) + } + + if err := es.repo.UnassignParentGroup(ctx, id, cids...); err != nil { + return errors.Wrap(errRemoveChildrenGroupEvent, err) + } + return nil +} + +func (es *eventHandler) removeAllChildrenGroupsHandler(ctx context.Context, data map[string]interface{}) error { + id, err := decodeRemoveAllChildrenGroupEvent(data) + if err != nil { + return errors.Wrap(errRemoveAllChildrenGroupEvent, err) + } + if err := es.repo.UnassignAllChildrenGroups(ctx, id); err != nil && err != repoerr.ErrNotFound { + return errors.Wrap(errRemoveAllChildrenGroupEvent, err) + } + return nil +} diff --git a/pkg/groups/events/doc.go b/pkg/groups/events/doc.go new file mode 100644 index 0000000000..8f09aa3abb --- /dev/null +++ b/pkg/groups/events/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package events provides the events sourcing of groups to +// provide listing in clients and channels concept definitions needed to support +package events diff --git a/pkg/messaging/nats/pubsub.go b/pkg/messaging/nats/pubsub.go index 04f90186ac..a975816bdc 100644 --- a/pkg/messaging/nats/pubsub.go +++ b/pkg/messaging/nats/pubsub.go @@ -94,7 +94,7 @@ func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) return ErrEmptyTopic } - nh := ps.natsHandler(cfg.Handler) + nh := ps.natsHandler(cfg.Handler, cfg.AckErr) consumerConfig := jetstream.ConsumerConfig{ Name: formatConsumerName(cfg.Topic, cfg.ID), @@ -104,6 +104,10 @@ func (ps *pubsub) Subscribe(ctx context.Context, cfg messaging.SubscriberConfig) FilterSubject: cfg.Topic, } + if cfg.Ordered { + consumerConfig.MaxAckPending = 1 + } + switch cfg.DeliveryPolicy { case messaging.DeliverNewPolicy: consumerConfig.DeliverPolicy = jetstream.DeliverNewPolicy @@ -140,17 +144,22 @@ func (ps *pubsub) Unsubscribe(ctx context.Context, id, topic string) error { } } -func (ps *pubsub) natsHandler(h messaging.MessageHandler) func(m jetstream.Msg) { +func (ps *pubsub) natsHandler(h messaging.MessageHandler, ackErr bool) func(m jetstream.Msg) { return func(m jetstream.Msg) { var msg messaging.Message if err := proto.Unmarshal(m.Data(), &msg); err != nil { ps.logger.Warn(fmt.Sprintf("Failed to unmarshal received message: %s", err)) - return } if err := h.Handle(&msg); err != nil { ps.logger.Warn(fmt.Sprintf("Failed to handle SuperMQ message: %s", err)) + if ackErr { + if err := m.Ack(); err != nil { + ps.logger.Warn(fmt.Sprintf("Failed to ack message: %s", err)) + } + } + return } if err := m.Ack(); err != nil { ps.logger.Warn(fmt.Sprintf("Failed to ack message: %s", err)) diff --git a/pkg/messaging/pubsub.go b/pkg/messaging/pubsub.go index 0c954a886b..393de64fef 100644 --- a/pkg/messaging/pubsub.go +++ b/pkg/messaging/pubsub.go @@ -39,6 +39,8 @@ type SubscriberConfig struct { Topic string Handler MessageHandler DeliveryPolicy DeliveryPolicy + Ordered bool + AckErr bool } // Subscriber specifies message subscription API. diff --git a/pkg/roles/repo/postgres/init.go b/pkg/roles/repo/postgres/init.go index 905205ef65..83af3bdfd3 100644 --- a/pkg/roles/repo/postgres/init.go +++ b/pkg/roles/repo/postgres/init.go @@ -29,24 +29,24 @@ func Migration(rolesTableNamePrefix, entityTableName, entityIDColumnName string) updated_at TIMESTAMP, updated_by VARCHAR(254), created_by VARCHAR(254), - CONSTRAINT unique_role_name_entity_id_constraint UNIQUE ( name, entity_id), - CONSTRAINT fk_entity_id FOREIGN KEY(entity_id) REFERENCES %s(%s) ON DELETE CASCADE - );`, rolesTableNamePrefix, entityTableName, entityIDColumnName), + CONSTRAINT %s_roles_unique_role_name_entity_id_constraint UNIQUE ( name, entity_id), + CONSTRAINT %s_roles_fk_entity_id FOREIGN KEY(entity_id) REFERENCES %s(%s) ON DELETE CASCADE + );`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, entityTableName, entityIDColumnName), fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s_role_actions ( role_id VARCHAR(254) NOT NULL, action VARCHAR(254) NOT NULL, - CONSTRAINT unique_domain_role_action_constraint UNIQUE ( role_id, action), - CONSTRAINT fk_%s_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE + CONSTRAINT %s_role_actions_unique_domain_role_action_constraint UNIQUE ( role_id, action), + CONSTRAINT %s_role_actions_fk_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE - );`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix), + );`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix), fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s_role_members ( role_id VARCHAR(254) NOT NULL, member_id VARCHAR(254) NOT NULL, - CONSTRAINT unique_role_member_constraint UNIQUE (role_id, member_id), - CONSTRAINT fk_%s_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE - );`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix), + CONSTRAINT %s_role_members_unique_role_member_constraint UNIQUE (role_id, member_id), + CONSTRAINT %s_role_members_fk_roles_id FOREIGN KEY(role_id) REFERENCES %s_roles(id) ON DELETE CASCADE + );`, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix, rolesTableNamePrefix), }, Down: []string{ fmt.Sprintf(`DROP TABLE IF EXISTS %s_roles`, rolesTableNamePrefix), diff --git a/pkg/roles/repo/postgres/roles.go b/pkg/roles/repo/postgres/roles.go index 64b8b5e2f0..bef5902efd 100644 --- a/pkg/roles/repo/postgres/roles.go +++ b/pkg/roles/repo/postgres/roles.go @@ -406,7 +406,7 @@ func (repo *Repository) RoleAddActions(ctx context.Context, role roles.Role, act return []string{}, postgres.HandleError(repoerr.ErrCreateEntity, err) } - return repo.RoleListActions(ctx, role.ID) + return actions, nil } func (repo *Repository) RoleListActions(ctx context.Context, roleID string) ([]string, error) { diff --git a/pkg/roles/rolemanager/events/consumer/decode.go b/pkg/roles/rolemanager/events/consumer/decode.go new file mode 100644 index 0000000000..f3b0f5d4d6 --- /dev/null +++ b/pkg/roles/rolemanager/events/consumer/decode.go @@ -0,0 +1,146 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "time" + + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/roles" +) + +var ( + errID = errors.New("missing or invalid 'id'") + errRoleID = errors.New("missing or invalid 'role_id'") + errName = errors.New("missing or invalid 'name'") + errEntityID = errors.New("missing or invalid 'entity_id'") + errActions = errors.New("missing or invalid 'actions'") + errMembers = errors.New("missing or invalid 'members'") + errCreatedAt = errors.New("failed to parse 'created_at' time") + errUpdatedAt = errors.New("failed to parse 'updated_at' time") + errNotString = errors.New("not string type") + + errInvalidRoleProvision = errors.New("invalid 'role_provisions'") + errRoleProvision = errors.New("failed to convert role_provisions interface'") + errRoleProvisionMembers = errors.New("failed to convert role_provisions member interface'") + errRoleProvisionActions = errors.New("failed to convert role_provisions action interface'") +) + +const ( + layout = "2006-01-02T15:04:05.999999Z" +) + +func ToRole(data map[string]interface{}) (roles.Role, error) { + var r roles.Role + + id, ok := data["id"].(string) + if !ok { + return roles.Role{}, errID + } + r.ID = id + + name, ok := data["name"].(string) + if !ok { + return roles.Role{}, errName + } + r.Name = name + + eid, ok := data["entity_id"].(string) + if !ok { + return roles.Role{}, errEntityID + } + r.EntityID = eid + + // Following fields of groups are allowed to be empty. + + cat, ok := data["created_at"].(string) + if ok { + ct, err := time.Parse(layout, cat) + if err != nil { + return roles.Role{}, errors.Wrap(errCreatedAt, err) + } + r.CreatedAt = ct + } + + cby, ok := data["created_by"].(string) + if ok { + r.CreatedBy = cby + } + + uat, ok := data["updated_at"].(string) + if ok { + ut, err := time.Parse(layout, uat) + if err != nil { + return roles.Role{}, errors.Wrap(errUpdatedAt, err) + } + r.UpdatedAt = ut + } + + uby, ok := data["updated_by"].(string) + if ok { + r.UpdatedBy = uby + } + + return r, nil +} + +func ToStrings(data []interface{}) ([]string, error) { + var strs []string + for _, i := range data { + str, ok := i.(string) + if !ok { + return []string{}, errNotString + } + strs = append(strs, str) + } + return strs, nil +} + +func ToRoleProvision(data map[string]interface{}) (roles.RoleProvision, error) { + var rp roles.RoleProvision + + r, err := ToRole(data) + if err != nil { + return roles.RoleProvision{}, err + } + rp.Role = r + + // Following fields of groups are allowed to be empty. + + opActs, ok := data["optional_actions"].([]interface{}) + if ok { + a, err := ToStrings(opActs) + if err != nil { + return roles.RoleProvision{}, errors.Wrap(errRoleProvisionActions, err) + } + rp.OptionalActions = a + } + + opMems, ok := data["optional_members"].([]interface{}) + if ok { + m, err := ToStrings(opMems) + if err != nil { + return roles.RoleProvision{}, errors.Wrap(errRoleProvisionMembers, err) + } + rp.OptionalMembers = m + } + + return rp, nil +} + +func ToRoleProvisions(data []interface{}) ([]roles.RoleProvision, error) { + var rps []roles.RoleProvision + for _, d := range data { + irp, ok := d.(map[string]interface{}) + if !ok { + return []roles.RoleProvision{}, errInvalidRoleProvision + } + rp, err := ToRoleProvision(irp) + if err != nil { + return []roles.RoleProvision{}, errors.Wrap(errRoleProvision, err) + } + rps = append(rps, rp) + } + return rps, nil +} diff --git a/pkg/roles/rolemanager/events/consumer/handler.go b/pkg/roles/rolemanager/events/consumer/handler.go new file mode 100644 index 0000000000..ebb3d8a654 --- /dev/null +++ b/pkg/roles/rolemanager/events/consumer/handler.go @@ -0,0 +1,188 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "context" + "fmt" + + "github.com/absmach/supermq/pkg/errors" + repoerr "github.com/absmach/supermq/pkg/errors/repository" + "github.com/absmach/supermq/pkg/roles" +) + +const ( + errAddEntityRoleEvent = "failed to consume %s add role event : %w" + errUpdateEntityRoleEvent = "failed to consume %s update role event : %w" + errRemoveEntityRoleEvent = "failed to consume %s remove role event : %w" + errAddEntityRoleActionsEvent = "failed to consume %s add role actions event : %w" + errRemoveEntityRoleActionsEvent = "failed to consume %s remove role actions event : %w" + errRemoveEntityRoleAllActionsEvent = "failed to consume %s remove role all actions event : %w" + errAddEntityRoleMembersEvent = "failed to consume %s add role members event : %w" + errRemoveEntityRoleMembersEvent = "failed to consume %s remove role members event : %w" + errRemoveEntityRoleAllMembersEvent = "failed to consume %s remove role all members event : %w" +) + +type EventHandler struct { + entityType string + repo roles.Repository +} + +func NewEventHandler(entityType string, repo roles.Repository) EventHandler { + return EventHandler{ + entityType: entityType, + repo: repo, + } +} + +func (es *EventHandler) AddEntityRoleHandler(ctx context.Context, data map[string]interface{}) error { + rps, err := ToRoleProvision(data) + if err != nil { + return fmt.Errorf(errAddEntityRoleEvent, es.entityType, err) + } + if _, err := es.repo.AddRoles(ctx, []roles.RoleProvision{rps}); err != nil { + if !errors.Contains(err, repoerr.ErrConflict) { + return fmt.Errorf(errAddEntityRoleEvent, es.entityType, err) + } + } + + return nil +} + +func (es *EventHandler) UpdateEntityRoleHandler(ctx context.Context, data map[string]interface{}) error { + ro, err := ToRole(data) + if err != nil { + return fmt.Errorf(errUpdateEntityRoleEvent, es.entityType, err) + } + + if _, err = es.repo.UpdateRole(ctx, ro); err != nil { + return fmt.Errorf(errUpdateEntityRoleEvent, es.entityType, err) + } + + return nil +} + +func (es *EventHandler) RemoveEntityRoleHandler(ctx context.Context, data map[string]interface{}) error { + id, ok := data["role_id"].(string) + if !ok { + return fmt.Errorf(errRemoveEntityRoleEvent, es.entityType, errRoleID) + } + + if err := es.repo.RemoveRoles(ctx, []string{id}); err != nil { + return fmt.Errorf(errRemoveEntityRoleEvent, es.entityType, err) + } + + return nil +} + +func (es *EventHandler) AddEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error { + id, ok := data["role_id"].(string) + if !ok { + return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errRoleID) + } + iacts, ok := data["actions"].([]interface{}) + if !ok { + return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errActions) + } + acts, err := ToStrings(iacts) + if err != nil { + return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err) + } + + if _, err := es.repo.RoleAddActions(ctx, roles.Role{ID: id}, acts); err != nil { + return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err) + } + + return nil +} + +func (es *EventHandler) RemoveEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error { + id, ok := data["role_id"].(string) + if !ok { + return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errRoleID) + } + iacts, ok := data["actions"].([]interface{}) + if !ok { + return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, errActions) + } + acts, err := ToStrings(iacts) + if err != nil { + return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err) + } + + if err := es.repo.RoleRemoveActions(ctx, roles.Role{ID: id}, acts); err != nil { + return fmt.Errorf(errAddEntityRoleActionsEvent, es.entityType, err) + } + return nil +} + +func (es *EventHandler) RemoveAllEntityRoleActionsHandler(ctx context.Context, data map[string]interface{}) error { + id, ok := data["role_id"].(string) + if !ok { + return fmt.Errorf(errRemoveEntityRoleAllActionsEvent, es.entityType, errRoleID) + } + + if err := es.repo.RoleRemoveAllActions(ctx, roles.Role{ID: id}); err != nil { + return fmt.Errorf(errRemoveEntityRoleAllActionsEvent, es.entityType, err) + } + return nil +} + +func (es *EventHandler) AddEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error { + id, ok := data["role_id"].(string) + if !ok { + return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, errRoleID) + } + imems, ok := data["members"].([]interface{}) + if !ok { + return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, errMembers) + } + mems, err := ToStrings(imems) + if err != nil { + return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, err) + } + + if _, err := es.repo.RoleAddMembers(ctx, roles.Role{ID: id}, mems); err != nil { + return fmt.Errorf(errAddEntityRoleMembersEvent, es.entityType, err) + } + + return nil +} + +func (es *EventHandler) RemoveEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error { + id, ok := data["role_id"].(string) + if !ok { + return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, errRoleID) + } + imems, ok := data["members"].([]interface{}) + if !ok { + return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, errMembers) + } + mems, err := ToStrings(imems) + if err != nil { + return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, err) + } + + if err := es.repo.RoleRemoveMembers(ctx, roles.Role{ID: id}, mems); err != nil { + return fmt.Errorf(errRemoveEntityRoleMembersEvent, es.entityType, err) + } + + return nil +} + +func (es *EventHandler) RemoveAllEntityRoleMembersHandler(ctx context.Context, data map[string]interface{}) error { + id, ok := data["role_id"].(string) + if !ok { + return fmt.Errorf(errRemoveEntityRoleAllMembersEvent, es.entityType, errRoleID) + } + + if err := es.repo.RoleRemoveAllMembers(ctx, roles.Role{ID: id}); err != nil { + return fmt.Errorf(errRemoveEntityRoleAllMembersEvent, es.entityType, err) + } + return nil +} + +func (es *EventHandler) RemoveMemberFromAllEntityHandler(ctx context.Context, data map[string]interface{}) error { + return nil +} diff --git a/pkg/roles/rolemanager/events/streams.go b/pkg/roles/rolemanager/events/streams.go index 613799b76c..94c06a84e6 100644 --- a/pkg/roles/rolemanager/events/streams.go +++ b/pkg/roles/rolemanager/events/streams.go @@ -24,9 +24,10 @@ type RoleManagerEventStore struct { // events to event store. func NewRoleManagerEventStore(svcName, operationPrefix string, svc roles.RoleManager, publisher events.Publisher) RoleManagerEventStore { return RoleManagerEventStore{ - svcName: svcName, - svc: svc, - Publisher: publisher, + svcName: svcName, + operationPrefix: operationPrefix, + svc: svc, + Publisher: publisher, } } diff --git a/pkg/sdk/channels_test.go b/pkg/sdk/channels_test.go index 4cad6ef9d1..a30fd97f17 100644 --- a/pkg/sdk/channels_test.go +++ b/pkg/sdk/channels_test.go @@ -392,9 +392,11 @@ func TestListChannels(t *testing.T) { offset: offset, total: total, channelsPageMeta: channels.PageMetadata{ - Offset: offset, - Limit: limit, - Permission: defPermission, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: offset, + Limit: limit, }, svcRes: channels.Page{ PageMetadata: channels.PageMetadata{ @@ -417,8 +419,11 @@ func TestListChannels(t *testing.T) { offset: offset, limit: limit, channelsPageMeta: channels.PageMetadata{ - Offset: offset, - Limit: limit, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: offset, + Limit: limit, }, svcRes: channels.Page{}, authenticateErr: svcerr.ErrAuthentication, @@ -426,16 +431,20 @@ func TestListChannels(t *testing.T) { err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), }, { - desc: "list channels with empty token", - token: "", - domainID: validID, - offset: offset, - limit: limit, - channelsPageMeta: channels.PageMetadata{}, - svcRes: channels.Page{}, - svcErr: nil, - response: sdk.ChannelsPage{}, - err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), + desc: "list channels with empty token", + token: "", + domainID: validID, + offset: offset, + limit: limit, + channelsPageMeta: channels.PageMetadata{ + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + }, + svcRes: channels.Page{}, + svcErr: nil, + response: sdk.ChannelsPage{}, + err: errors.NewSDKErrorWithStatus(apiutil.ErrBearerToken, http.StatusUnauthorized), }, { desc: "list channels with zero limit", @@ -444,9 +453,11 @@ func TestListChannels(t *testing.T) { offset: offset, limit: 0, channelsPageMeta: channels.PageMetadata{ - Offset: offset, - Limit: 10, - Permission: defPermission, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: offset, + Limit: 10, }, svcRes: channels.Page{ PageMetadata: channels.PageMetadata{ @@ -464,16 +475,20 @@ func TestListChannels(t *testing.T) { err: nil, }, { - desc: "list channels with limit greater than max", - token: validToken, - domainID: domainID, - offset: offset, - limit: 110, - channelsPageMeta: channels.PageMetadata{}, - svcRes: channels.Page{}, - svcErr: nil, - response: sdk.ChannelsPage{}, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), + desc: "list channels with limit greater than max", + token: validToken, + domainID: domainID, + offset: offset, + limit: 110, + channelsPageMeta: channels.PageMetadata{ + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + }, + svcRes: channels.Page{}, + svcErr: nil, + response: sdk.ChannelsPage{}, + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), }, { desc: "list channels with level", @@ -483,9 +498,11 @@ func TestListChannels(t *testing.T) { limit: 1, level: 1, channelsPageMeta: channels.PageMetadata{ - Offset: offset, - Limit: 1, - Permission: defPermission, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: offset, + Limit: 1, }, svcRes: channels.Page{ PageMetadata: channels.PageMetadata{ @@ -510,10 +527,12 @@ func TestListChannels(t *testing.T) { limit: 10, metadata: sdk.Metadata{"name": "client_89"}, channelsPageMeta: channels.PageMetadata{ - Offset: offset, - Limit: 10, - Permission: defPermission, - Metadata: clients.Metadata{"name": "client_89"}, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: offset, + Limit: 10, + Metadata: clients.Metadata{"name": "client_89"}, }, svcRes: channels.Page{ PageMetadata: channels.PageMetadata{ @@ -539,11 +558,15 @@ func TestListChannels(t *testing.T) { metadata: sdk.Metadata{ "test": make(chan int), }, - channelsPageMeta: channels.PageMetadata{}, - svcRes: channels.Page{}, - svcErr: nil, - response: sdk.ChannelsPage{}, - err: errors.NewSDKError(errors.New("json: unsupported type: chan int")), + channelsPageMeta: channels.PageMetadata{ + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + }, + svcRes: channels.Page{}, + svcErr: nil, + response: sdk.ChannelsPage{}, + err: errors.NewSDKError(errors.New("json: unsupported type: chan int")), }, { desc: "list channels with service response that can't be unmarshalled", @@ -552,9 +575,11 @@ func TestListChannels(t *testing.T) { offset: 0, limit: 10, channelsPageMeta: channels.PageMetadata{ - Offset: 0, - Limit: 10, - Permission: defPermission, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: 0, + Limit: 10, }, svcRes: channels.Page{ PageMetadata: channels.PageMetadata{ diff --git a/pkg/sdk/clients.go b/pkg/sdk/clients.go index c52e6ecc09..c3f4c971a8 100644 --- a/pkg/sdk/clients.go +++ b/pkg/sdk/clients.go @@ -252,23 +252,6 @@ func (sdk mgSDK) DeleteClient(id, domainID, token string) errors.SDKError { return sdkerr } -func (sdk mgSDK) ListUserClients(userID, domainID string, pm PageMetadata, token string) (ClientsPage, errors.SDKError) { - url, err := sdk.withQueryParams(sdk.clientsURL, fmt.Sprintf("%s/%s/%s/%s", domainID, usersEndpoint, userID, clientsEndpoint), pm) - if err != nil { - return ClientsPage{}, errors.NewSDKError(err) - } - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, token, nil, nil, http.StatusOK) - if sdkerr != nil { - return ClientsPage{}, sdkerr - } - cp := ClientsPage{} - if err := json.Unmarshal(body, &cp); err != nil { - return ClientsPage{}, errors.NewSDKError(err) - } - - return cp, nil -} - func (sdk mgSDK) CreateClientRole(id, domainID string, rq RoleReq, token string) (Role, errors.SDKError) { return sdk.createRole(sdk.clientsURL, clientsEndpoint, id, domainID, rq, token) } diff --git a/pkg/sdk/clients_test.go b/pkg/sdk/clients_test.go index c8bdbafcab..e702444222 100644 --- a/pkg/sdk/clients_test.go +++ b/pkg/sdk/clients_test.go @@ -357,9 +357,11 @@ func TestListClients(t *testing.T) { Limit: 100, }, svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: 0, + Limit: 100, }, svcRes: clients.ClientsPage{ Page: clients.Page{ @@ -387,9 +389,11 @@ func TestListClients(t *testing.T) { Limit: 100, }, svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: 0, + Limit: 100, }, svcRes: clients.ClientsPage{}, authenticateErr: svcerr.ErrAuthentication, @@ -404,7 +408,11 @@ func TestListClients(t *testing.T) { Offset: 0, Limit: 1000, }, - svcReq: clients.Page{}, + svcReq: clients.Page{ + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + }, svcRes: clients.ClientsPage{}, svcErr: nil, response: sdk.ClientsPage{}, @@ -419,7 +427,11 @@ func TestListClients(t *testing.T) { Limit: 100, Name: strings.Repeat("a", 1025), }, - svcReq: clients.Page{}, + svcReq: clients.Page{ + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + }, svcRes: clients.ClientsPage{}, svcErr: nil, response: sdk.ClientsPage{}, @@ -435,10 +447,12 @@ func TestListClients(t *testing.T) { Status: clients.DisabledStatus.String(), }, svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, - Status: clients.DisabledStatus, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: 0, + Limit: 100, + Status: clients.DisabledStatus, }, svcRes: clients.ClientsPage{ Page: clients.Page{ @@ -468,10 +482,12 @@ func TestListClients(t *testing.T) { Tag: "tag1", }, svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, - Tag: "tag1", + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: 0, + Limit: 100, + Tag: "tag1", }, svcRes: clients.ClientsPage{ Page: clients.Page{ @@ -502,7 +518,11 @@ func TestListClients(t *testing.T) { "test": make(chan int), }, }, - svcReq: clients.Page{}, + svcReq: clients.Page{ + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + }, svcRes: clients.ClientsPage{}, svcErr: nil, response: sdk.ClientsPage{}, @@ -517,9 +537,11 @@ func TestListClients(t *testing.T) { Limit: 100, }, svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, + Actions: []string{}, + Order: "updated_at", + Dir: "asc", + Offset: 0, + Limit: 100, }, svcRes: clients.ClientsPage{ Page: clients.Page{ @@ -547,12 +569,12 @@ func TestListClients(t *testing.T) { tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} } authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr) - svcCall := tsvc.On("ListClients", mock.Anything, tc.session, mock.Anything, tc.svcReq).Return(tc.svcRes, tc.svcErr) + svcCall := tsvc.On("ListClients", mock.Anything, tc.session, tc.svcReq).Return(tc.svcRes, tc.svcErr) resp, err := mgsdk.Clients(tc.pageMeta, tc.domainID, tc.token) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, mock.Anything, tc.svcReq) + ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, tc.svcReq) assert.True(t, ok) } svcCall.Unset() @@ -1400,259 +1422,6 @@ func TestDeleteClient(t *testing.T) { } } -func TestListUserClients(t *testing.T) { - ts, tsvc, auth := setupClients() - defer ts.Close() - - var sdkClients []sdk.Client - for i := 10; i < 100; i++ { - c := generateTestClient(t) - if i == 50 { - c.Status = clients.DisabledStatus.String() - c.Tags = []string{"tag1", "tag2"} - } - sdkClients = append(sdkClients, c) - } - - conf := sdk.Config{ - ClientsURL: ts.URL, - } - mgsdk := sdk.NewSDK(conf) - - cases := []struct { - desc string - token string - session smqauthn.Session - userID string - domainID string - pageMeta sdk.PageMetadata - svcReq clients.Page - svcRes clients.ClientsPage - svcErr error - authenticateErr error - response sdk.ClientsPage - err errors.SDKError - }{ - { - desc: "list user clients successfully", - token: validToken, - userID: validID, - domainID: domainID, - pageMeta: sdk.PageMetadata{ - Offset: 0, - Limit: 100, - }, - svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, - }, - svcRes: clients.ClientsPage{ - Page: clients.Page{ - Offset: 0, - Limit: 100, - Total: uint64(len(sdkClients)), - }, - Clients: convertClients(sdkClients...), - }, - svcErr: nil, - response: sdk.ClientsPage{ - PageRes: sdk.PageRes{ - Limit: 100, - Total: uint64(len(sdkClients)), - }, - Clients: sdkClients, - }, - }, - { - desc: "list user clients with an invalid token", - token: invalidToken, - userID: validID, - domainID: domainID, - pageMeta: sdk.PageMetadata{ - Offset: 0, - Limit: 100, - }, - svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, - }, - svcRes: clients.ClientsPage{}, - authenticateErr: svcerr.ErrAuthentication, - response: sdk.ClientsPage{}, - err: errors.NewSDKErrorWithStatus(svcerr.ErrAuthentication, http.StatusUnauthorized), - }, - { - desc: "list user clients with limit greater than max", - token: validToken, - userID: validID, - domainID: domainID, - pageMeta: sdk.PageMetadata{ - Offset: 0, - Limit: 1000, - }, - svcReq: clients.Page{}, - svcRes: clients.ClientsPage{}, - svcErr: nil, - response: sdk.ClientsPage{}, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), - }, - { - desc: "list user clients with name size greater than max", - token: validToken, - userID: validID, - domainID: domainID, - pageMeta: sdk.PageMetadata{ - Offset: 0, - Limit: 100, - Name: strings.Repeat("a", 1025), - }, - svcReq: clients.Page{}, - svcRes: clients.ClientsPage{}, - svcErr: nil, - response: sdk.ClientsPage{}, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest), - }, - { - desc: "list user clients with status", - token: validToken, - userID: validID, - domainID: domainID, - pageMeta: sdk.PageMetadata{ - Offset: 0, - Limit: 100, - Status: clients.DisabledStatus.String(), - }, - svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, - Status: clients.DisabledStatus, - }, - svcRes: clients.ClientsPage{ - Page: clients.Page{ - Offset: 0, - Limit: 100, - Total: 1, - }, - Clients: convertClients(sdkClients[50]), - }, - svcErr: nil, - response: sdk.ClientsPage{ - PageRes: sdk.PageRes{ - Limit: 100, - Total: 1, - }, - Clients: []sdk.Client{sdkClients[50]}, - }, - err: nil, - }, - { - desc: "list user clients with tags", - token: validToken, - userID: validID, - domainID: domainID, - pageMeta: sdk.PageMetadata{ - Offset: 0, - Limit: 100, - Tag: "tag1", - }, - svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, - Tag: "tag1", - }, - svcRes: clients.ClientsPage{ - Page: clients.Page{ - Offset: 0, - Limit: 100, - Total: 1, - }, - Clients: convertClients(sdkClients[50]), - }, - svcErr: nil, - response: sdk.ClientsPage{ - PageRes: sdk.PageRes{ - Limit: 100, - Total: 1, - }, - Clients: []sdk.Client{sdkClients[50]}, - }, - err: nil, - }, - { - desc: "list user clients with invalid metadata", - token: validToken, - userID: validID, - domainID: domainID, - pageMeta: sdk.PageMetadata{ - Offset: 0, - Limit: 100, - Metadata: map[string]interface{}{ - "test": make(chan int), - }, - }, - svcReq: clients.Page{}, - svcRes: clients.ClientsPage{}, - svcErr: nil, - response: sdk.ClientsPage{}, - err: errors.NewSDKError(errors.New("json: unsupported type: chan int")), - }, - { - desc: "list user clients with response that can't be unmarshalled", - token: validToken, - domainID: domainID, - pageMeta: sdk.PageMetadata{ - Offset: 0, - Limit: 100, - }, - svcReq: clients.Page{ - Offset: 0, - Limit: 100, - Permission: defPermission, - }, - svcRes: clients.ClientsPage{ - Page: clients.Page{ - Offset: 0, - Limit: 100, - Total: 1, - }, - Clients: []clients.Client{{ - Name: sdkClients[0].Name, - Tags: sdkClients[0].Tags, - Credentials: clients.Credentials(sdkClients[0].Credentials), - Metadata: clients.Metadata{ - "test": make(chan int), - }, - }}, - }, - svcErr: nil, - response: sdk.ClientsPage{}, - err: errors.NewSDKError(errors.New("unexpected end of JSON input")), - }, - } - for _, tc := range cases { - t.Run(tc.desc, func(t *testing.T) { - if tc.token == validToken { - tc.session = smqauthn.Session{DomainUserID: domainID + "_" + validID, UserID: validID, DomainID: domainID} - } - authCall := auth.On("Authenticate", mock.Anything, mock.Anything).Return(tc.session, tc.authenticateErr) - svcCall := tsvc.On("ListClients", mock.Anything, tc.session, tc.userID, tc.svcReq).Return(tc.svcRes, tc.svcErr) - resp, err := mgsdk.ListUserClients(tc.userID, tc.domainID, tc.pageMeta, tc.token) - assert.Equal(t, tc.err, err) - assert.Equal(t, tc.response, resp) - if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "ListClients", mock.Anything, tc.session, tc.userID, tc.svcReq) - assert.True(t, ok) - } - svcCall.Unset() - authCall.Unset() - }) - } -} - func TestSetClientParent(t *testing.T) { ts, csvc, auth := setupClients() defer ts.Close() diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index 5eb948a124..e4229814fd 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -4393,67 +4393,6 @@ func (_c *SDK_ListDomainUsers_Call) RunAndReturn(run func(string, sdk.PageMetada return _c } -// ListUserClients provides a mock function with given fields: userID, domainID, pm, token -func (_m *SDK) ListUserClients(userID string, domainID string, pm sdk.PageMetadata, token string) (sdk.ClientsPage, errors.SDKError) { - ret := _m.Called(userID, domainID, pm, token) - - if len(ret) == 0 { - panic("no return value specified for ListUserClients") - } - - var r0 sdk.ClientsPage - var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string, sdk.PageMetadata, string) (sdk.ClientsPage, errors.SDKError)); ok { - return rf(userID, domainID, pm, token) - } - if rf, ok := ret.Get(0).(func(string, string, sdk.PageMetadata, string) sdk.ClientsPage); ok { - r0 = rf(userID, domainID, pm, token) - } else { - r0 = ret.Get(0).(sdk.ClientsPage) - } - - if rf, ok := ret.Get(1).(func(string, string, sdk.PageMetadata, string) errors.SDKError); ok { - r1 = rf(userID, domainID, pm, token) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(errors.SDKError) - } - } - - return r0, r1 -} - -// SDK_ListUserClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListUserClients' -type SDK_ListUserClients_Call struct { - *mock.Call -} - -// ListUserClients is a helper method to define mock.On call -// - userID string -// - domainID string -// - pm sdk.PageMetadata -// - token string -func (_e *SDK_Expecter) ListUserClients(userID interface{}, domainID interface{}, pm interface{}, token interface{}) *SDK_ListUserClients_Call { - return &SDK_ListUserClients_Call{Call: _e.mock.On("ListUserClients", userID, domainID, pm, token)} -} - -func (_c *SDK_ListUserClients_Call) Run(run func(userID string, domainID string, pm sdk.PageMetadata, token string)) *SDK_ListUserClients_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(sdk.PageMetadata), args[3].(string)) - }) - return _c -} - -func (_c *SDK_ListUserClients_Call) Return(_a0 sdk.ClientsPage, _a1 errors.SDKError) *SDK_ListUserClients_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *SDK_ListUserClients_Call) RunAndReturn(run func(string, string, sdk.PageMetadata, string) (sdk.ClientsPage, errors.SDKError)) *SDK_ListUserClients_Call { - _c.Call.Return(run) - return _c -} - // Members provides a mock function with given fields: groupID, domainID, pm, token func (_m *SDK) Members(groupID string, domainID string, pm sdk.PageMetadata, token string) (sdk.UsersPage, errors.SDKError) { ret := _m.Called(groupID, domainID, pm, token) diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index 445ed142aa..f16dd92385 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -515,17 +515,6 @@ type SDK interface { // fmt.Println(err) RemoveClientParent(id, domainID, groupID, token string) errors.SDKError - // ListUserClients returns list of clients for the given user ID and filters. - // - // example: - // pm := sdk.PageMetadata{ - // Offset: 0, - // Limit: 10, - // } - // clients, _ := sdk.ListUserClients("userID", "domainID", pm,"token") - // fmt.Println(clients) - ListUserClients(userID, domainID string, pm PageMetadata, token string) (ClientsPage, errors.SDKError) - // CreateClientRole creates new client role and returns its id. // // example: diff --git a/pkg/sdk/setup_test.go b/pkg/sdk/setup_test.go index c55789dfe4..4d3ba83a4c 100644 --- a/pkg/sdk/setup_test.go +++ b/pkg/sdk/setup_test.go @@ -216,7 +216,6 @@ func convertChannel(g sdk.Channel) mgchannels.Channel { UpdatedAt: g.UpdatedAt, UpdatedBy: g.UpdatedBy, Status: status, - Permissions: g.Permissions, } }