diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go index a4622e1b70..05c3f4bf8c 100644 --- a/server/cmd/museum/main.go +++ b/server/cmd/museum/main.go @@ -535,6 +535,7 @@ func main() { privateAPI.GET("/collections", collectionHandler.Get) privateAPI.GET("/collections/v2", collectionHandler.GetV2) privateAPI.POST("/collections/share", collectionHandler.Share) + privateAPI.POST("/collections/join-link", collectionHandler.JoinLink) privateAPI.POST("/collections/share-url", collectionHandler.ShareURL) privateAPI.PUT("/collections/share-url", collectionHandler.UpdateShareURL) privateAPI.DELETE("/collections/share-url/:collectionID", collectionHandler.UnShareURL) diff --git a/server/ente/collection.go b/server/ente/collection.go index 71b4c50ac2..9b9a0759d9 100644 --- a/server/ente/collection.go +++ b/server/ente/collection.go @@ -97,6 +97,11 @@ type AlterShareRequest struct { Role *CollectionParticipantRole `json:"role"` } +type JoinCollectionViaLinkRequest struct { + CollectionID int64 `json:"collectionID" binding:"required"` + EncryptedKey string `json:"encryptedKey" binding:"required"` +} + // AddFilesRequest represents a request to add files to a collection type AddFilesRequest struct { CollectionID int64 `json:"collectionID" binding:"required"` diff --git a/server/ente/public_collection.go b/server/ente/public_collection.go index 6e71e35b13..7a477c4365 100644 --- a/server/ente/public_collection.go +++ b/server/ente/public_collection.go @@ -3,7 +3,7 @@ package ente import ( "database/sql/driver" "encoding/json" - + "github.com/ente-io/museum/pkg/utils/time" "github.com/ente-io/stacktrace" ) @@ -11,8 +11,10 @@ import ( type CreatePublicAccessTokenRequest struct { CollectionID int64 `json:"collectionID" binding:"required"` EnableCollect bool `json:"enableCollect"` - ValidTill int64 `json:"validTill"` - DeviceLimit int `json:"deviceLimit"` + // defaults to true + EnableJoin *bool `json:"joinViaLink"` + ValidTill int64 `json:"validTill"` + DeviceLimit int `json:"deviceLimit"` } type UpdatePublicAccessTokenRequest struct { @@ -26,6 +28,7 @@ type UpdatePublicAccessTokenRequest struct { EnableDownload *bool `json:"enableDownload"` EnableCollect *bool `json:"enableCollect"` DisablePassword *bool `json:"disablePassword"` + EnableJoin *bool `json:"enableJoin"` } type VerifyPasswordRequest struct { @@ -50,6 +53,23 @@ type PublicCollectionToken struct { OpsLimit *int64 EnableDownload bool EnableCollect bool + EnableJoin bool +} + +func (p PublicCollectionToken) CanJoin() error { + if p.IsDisabled { + return NewBadRequestWithMessage("link disabled") + } + if p.ValidTill > 0 && p.ValidTill < time.Microseconds() { + return NewBadRequestWithMessage("token expired") + } + if !p.EnableDownload { + return NewBadRequestWithMessage("can not join as download is disabled") + } + if !p.EnableJoin { + return NewBadRequestWithMessage("can not join as join is disabled") + } + return nil } // PublicURL represents information about non-disabled public url for a collection @@ -62,9 +82,10 @@ type PublicURL struct { EnableCollect bool `json:"enableCollect"` PasswordEnabled bool `json:"passwordEnabled"` // Nonce contains the nonce value for the password if the link is password protected. - Nonce *string `json:"nonce,omitempty"` - MemLimit *int64 `json:"memLimit,omitempty"` - OpsLimit *int64 `json:"opsLimit,omitempty"` + Nonce *string `json:"nonce,omitempty"` + MemLimit *int64 `json:"memLimit,omitempty"` + OpsLimit *int64 `json:"opsLimit,omitempty"` + EnableJoin bool `json:"enableJoinViaLink"` } type PublicAccessContext struct { diff --git a/server/migrations/95_public_collection_join_album.down.sql b/server/migrations/95_public_collection_join_album.down.sql new file mode 100644 index 0000000000..134128116f --- /dev/null +++ b/server/migrations/95_public_collection_join_album.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE public_collection_tokens + DROP COLUMN IF EXISTS enable_join; diff --git a/server/migrations/95_public_collection_join_album.up.sql b/server/migrations/95_public_collection_join_album.up.sql new file mode 100644 index 0000000000..d5b73572d5 --- /dev/null +++ b/server/migrations/95_public_collection_join_album.up.sql @@ -0,0 +1,9 @@ +BEGIN; +ALTER table public_collection_tokens + ADD COLUMN IF NOT EXISTS enable_join bool DEFAULT TRUE; + +UPDATE public_collection_tokens SET enable_join = FALSE; + +ALTER TABLE public_collection_tokens ALTER COLUMN enable_join SET NOT NULL; + +COMMIT; diff --git a/server/pkg/api/collection.go b/server/pkg/api/collection.go index 94918bcf9d..57cc8c4f28 100644 --- a/server/pkg/api/collection.go +++ b/server/pkg/api/collection.go @@ -122,6 +122,20 @@ func (h *CollectionHandler) Share(c *gin.Context) { }) } +func (h *CollectionHandler) JoinLink(c *gin.Context) { + var request ente.JoinCollectionViaLinkRequest + if err := c.ShouldBindJSON(&request); err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + err := h.Controller.JoinViaLink(c, request) + if err != nil { + handler.Error(c, stacktrace.Propagate(err, "")) + return + } + c.JSON(http.StatusOK, gin.H{}) +} + // ShareURL generates a publicly sharable url func (h *CollectionHandler) ShareURL(c *gin.Context) { var request ente.CreatePublicAccessTokenRequest diff --git a/server/pkg/controller/collection.go b/server/pkg/controller/collection.go index dca53f80c3..7476de1c39 100644 --- a/server/pkg/controller/collection.go +++ b/server/pkg/controller/collection.go @@ -181,6 +181,47 @@ func (c *CollectionController) Share(ctx *gin.Context, req ente.AlterShareReques return sharees, nil } +func (c *CollectionController) JoinViaLink(ctx *gin.Context, req ente.JoinCollectionViaLinkRequest) error { + userID := auth.GetUserID(ctx.Request.Header) + collection, err := c.CollectionRepo.Get(req.CollectionID) + if err != nil { + return stacktrace.Propagate(err, "") + } + if collection.Owner.ID == userID { + return stacktrace.Propagate(ente.ErrBadRequest, "owner can not join via link") + } + if !collection.AllowSharing() { + return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("joining %s is not allowed", collection.Type)) + } + publicCollectionToken, err := c.PublicCollectionCtrl.GetActivePublicCollectionToken(ctx, req.CollectionID) + if err != nil { + return stacktrace.Propagate(err, "") + } + + if canJoin := publicCollectionToken.CanJoin(); canJoin != nil { + return stacktrace.Propagate(ente.ErrBadRequest, fmt.Sprintf("can not join collection: %s", canJoin.Error())) + } + accessToken := auth.GetAccessToken(ctx) + if publicCollectionToken.Token != accessToken { + return stacktrace.Propagate(ente.ErrPermissionDenied, "token doesn't match collection") + } + if publicCollectionToken.PassHash != nil && *publicCollectionToken.PassHash != "" { + accessTokenJWT := auth.GetAccessTokenJWT(ctx) + if passCheckErr := c.PublicCollectionCtrl.ValidateJWTToken(ctx, accessTokenJWT, *publicCollectionToken.PassHash); passCheckErr != nil { + return stacktrace.Propagate(passCheckErr, "") + } + } + err = c.BillingCtrl.HasActiveSelfOrFamilySubscription(collection.Owner.ID, true) + if err != nil { + return stacktrace.Propagate(err, "") + } + role := ente.VIEWER + if publicCollectionToken.EnableCollect { + role = ente.COLLABORATOR + } + return c.CollectionRepo.Share(req.CollectionID, collection.Owner.ID, userID, req.EncryptedKey, role, time.Microseconds()) +} + // UnShare unshares a collection with a user func (c *CollectionController) UnShare(ctx *gin.Context, cID int64, fromUserID int64, toUserEmail string) ([]ente.CollectionUser, error) { toUserID, err := c.UserRepo.GetUserIDWithEmail(toUserEmail) diff --git a/server/pkg/controller/public_collection.go b/server/pkg/controller/public_collection.go index d526837ddd..d9b5e29b77 100644 --- a/server/pkg/controller/public_collection.go +++ b/server/pkg/controller/public_collection.go @@ -61,7 +61,8 @@ type PublicCollectionController struct { func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req ente.CreatePublicAccessTokenRequest) (ente.PublicURL, error) { accessToken := shortuuid.New()[0:AccessTokenLength] - err := c.PublicCollectionRepo.Insert(ctx, req.CollectionID, accessToken, req.ValidTill, req.DeviceLimit, req.EnableCollect) + err := c.PublicCollectionRepo. + Insert(ctx, req.CollectionID, accessToken, req.ValidTill, req.DeviceLimit, req.EnableCollect, req.EnableJoin) if err != nil { if errors.Is(err, ente.ErrActiveLinkAlreadyExists) { collectionToPubUrlMap, err2 := c.PublicCollectionRepo.GetCollectionToActivePublicURLMap(ctx, []int64{req.CollectionID}) @@ -90,6 +91,10 @@ func (c *PublicCollectionController) CreateAccessToken(ctx context.Context, req return response, nil } +func (c *PublicCollectionController) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.PublicCollectionToken, error) { + return c.PublicCollectionRepo.GetActivePublicCollectionToken(ctx, collectionID) +} + func (c *PublicCollectionController) CreateFile(ctx *gin.Context, file ente.File, app ente.App) (ente.File, error) { collection, err := c.GetPublicCollection(ctx, true) if err != nil { @@ -146,6 +151,9 @@ func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req en if req.EnableCollect != nil { publicCollectionToken.EnableCollect = *req.EnableCollect } + if req.EnableJoin != nil { + publicCollectionToken.EnableJoin = *req.EnableJoin + } err = c.PublicCollectionRepo.UpdatePublicCollectionToken(ctx, publicCollectionToken) if err != nil { return ente.PublicURL{}, stacktrace.Propagate(err, "") @@ -156,6 +164,7 @@ func (c *PublicCollectionController) UpdateSharedUrl(ctx context.Context, req en ValidTill: publicCollectionToken.ValidTill, EnableDownload: publicCollectionToken.EnableDownload, EnableCollect: publicCollectionToken.EnableCollect, + EnableJoin: publicCollectionToken.EnableJoin, PasswordEnabled: publicCollectionToken.PassHash != nil && *publicCollectionToken.PassHash != "", Nonce: publicCollectionToken.Nonce, MemLimit: publicCollectionToken.MemLimit, diff --git a/server/pkg/repo/collection.go b/server/pkg/repo/collection.go index bb18750116..f7ddcfeced 100644 --- a/server/pkg/repo/collection.go +++ b/server/pkg/repo/collection.go @@ -155,7 +155,7 @@ func (repo *CollectionRepository) GetCollectionsOwnedByUserV2(userID int64, upda SELECT c.collection_id, c.owner_id, c.encrypted_key,c.key_decryption_nonce, c.name, c.encrypted_name, c.name_decryption_nonce, c.type, c.app, c.attributes, c.updation_time, c.is_deleted, c.magic_metadata, c.pub_magic_metadata, users.user_id, users.encrypted_email, users.email_decryption_nonce, cs.role_type, -pct.access_token, pct.valid_till, pct.device_limit, pct.created_at, pct.updated_at, pct.pw_hash, pct.pw_nonce, pct.mem_limit, pct.ops_limit, pct.enable_download, pct.enable_collect +pct.access_token, pct.valid_till, pct.device_limit, pct.created_at, pct.updated_at, pct.pw_hash, pct.pw_nonce, pct.mem_limit, pct.ops_limit, pct.enable_download, pct.enable_collect, pct.enable_join FROM collections c LEFT JOIN collection_shares cs ON (cs.collection_id = c.collection_id AND cs.is_deleted = false) @@ -175,14 +175,14 @@ pct.access_token, pct.valid_till, pct.device_limit, pct.created_at, pct.updated_ var c ente.Collection var name, encryptedName, nameDecryptionNonce sql.NullString var pctDeviceLimit sql.NullInt32 - var pctEnableDownload, pctEnableCollect sql.NullBool + var pctEnableDownload, pctEnableCollect, pctEnableJoin sql.NullBool var shareUserID, pctValidTill, pctCreatedAt, pctUpdatedAt, pctMemLimit, pctOpsLimit sql.NullInt64 var encryptedEmail, nonce []byte var shareeRoleType, pctToken, pctPwHash, pctPwNonce sql.NullString if err := rows.Scan(&c.ID, &c.Owner.ID, &c.EncryptedKey, &c.KeyDecryptionNonce, &name, &encryptedName, &nameDecryptionNonce, &c.Type, &c.App, &c.Attributes, &c.UpdationTime, &c.IsDeleted, &c.MagicMetadata, &c.PublicMagicMetadata, &shareUserID, &encryptedEmail, &nonce, &shareeRoleType, - &pctToken, &pctValidTill, &pctDeviceLimit, &pctCreatedAt, &pctUpdatedAt, &pctPwHash, &pctPwNonce, &pctMemLimit, &pctOpsLimit, &pctEnableDownload, &pctEnableCollect); err != nil { + &pctToken, &pctValidTill, &pctDeviceLimit, &pctCreatedAt, &pctUpdatedAt, &pctPwHash, &pctPwNonce, &pctMemLimit, &pctOpsLimit, &pctEnableDownload, &pctEnableCollect, &pctEnableJoin); err != nil { return nil, stacktrace.Propagate(err, "") } @@ -222,6 +222,7 @@ pct.access_token, pct.valid_till, pct.device_limit, pct.created_at, pct.updated_ EnableDownload: pctEnableDownload.Bool, EnableCollect: pctEnableCollect.Bool, PasswordEnabled: pctPwNonce.Valid, + EnableJoin: pctEnableJoin.Bool, } if pctPwNonce.Valid { url.Nonce = &pctPwNonce.String diff --git a/server/pkg/repo/public_collection.go b/server/pkg/repo/public_collection.go index 91b96fca78..49918f4c88 100644 --- a/server/pkg/repo/public_collection.go +++ b/server/pkg/repo/public_collection.go @@ -36,10 +36,15 @@ func (pcr *PublicCollectionRepository) GetAlbumUrl(token string) string { } func (pcr *PublicCollectionRepository) Insert(ctx context.Context, - cID int64, token string, validTill int64, deviceLimit int, enableCollect bool) error { + cID int64, token string, validTill int64, deviceLimit int, enableCollect bool, enableJoin *bool) error { + // default value for enableJoin is true + join := true + if enableJoin != nil { + join = *enableJoin + } _, err := pcr.DB.ExecContext(ctx, `INSERT INTO public_collection_tokens - (collection_id, access_token, valid_till, device_limit, enable_collect) VALUES ($1, $2, $3, $4, $5)`, - cID, token, validTill, deviceLimit, enableCollect) + (collection_id, access_token, valid_till, device_limit, enable_collect, enable_join) VALUES ($1, $2, $3, $4, $5, $6)`, + cID, token, validTill, deviceLimit, enableCollect, join) if err != nil && err.Error() == "pq: duplicate key value violates unique constraint \"public_active_collection_unique_idx\"" { return ente.ErrActiveLinkAlreadyExists } @@ -91,14 +96,14 @@ func (pcr *PublicCollectionRepository) GetCollectionToActivePublicURLMap(ctx con // Note: The token could be expired or deviceLimit is already reached func (pcr *PublicCollectionRepository) GetActivePublicCollectionToken(ctx context.Context, collectionID int64) (ente.PublicCollectionToken, error) { row := pcr.DB.QueryRowContext(ctx, `SELECT id, collection_id, access_token, valid_till, device_limit, - is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download, enable_collect FROM + is_disabled, pw_hash, pw_nonce, mem_limit, ops_limit, enable_download, enable_collect, enable_join FROM public_collection_tokens WHERE collection_id = $1 and is_disabled = FALSE`, collectionID) //defer rows.Close() ret := ente.PublicCollectionToken{} err := row.Scan(&ret.ID, &ret.CollectionID, &ret.Token, &ret.ValidTill, &ret.DeviceLimit, - &ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload, &ret.EnableCollect) + &ret.IsDisabled, &ret.PassHash, &ret.Nonce, &ret.MemLimit, &ret.OpsLimit, &ret.EnableDownload, &ret.EnableCollect, &ret.EnableJoin) if err != nil { return ente.PublicCollectionToken{}, stacktrace.Propagate(err, "") } @@ -108,9 +113,9 @@ func (pcr *PublicCollectionRepository) GetActivePublicCollectionToken(ctx contex // UpdatePublicCollectionToken will update the row for corresponding public collection token func (pcr *PublicCollectionRepository) UpdatePublicCollectionToken(ctx context.Context, pct ente.PublicCollectionToken) error { _, err := pcr.DB.ExecContext(ctx, `UPDATE public_collection_tokens SET valid_till = $1, device_limit = $2, - pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7, enable_collect = $8 - where id = $9`, - pct.ValidTill, pct.DeviceLimit, pct.PassHash, pct.Nonce, pct.MemLimit, pct.OpsLimit, pct.EnableDownload, pct.EnableCollect, pct.ID) + pw_hash = $3, pw_nonce = $4, mem_limit = $5, ops_limit = $6, enable_download = $7, enable_collect = $8, enable_join = $9 + where id = $10`, + pct.ValidTill, pct.DeviceLimit, pct.PassHash, pct.Nonce, pct.MemLimit, pct.OpsLimit, pct.EnableDownload, pct.EnableCollect, pct.EnableJoin, pct.ID) return stacktrace.Propagate(err, "failed to update public collection token") }