diff --git a/api/r0/upload.go b/api/r0/upload.go index 74d7eef8..1f7b803c 100644 --- a/api/r0/upload.go +++ b/api/r0/upload.go @@ -13,6 +13,7 @@ import ( "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/controllers/info_controller" "github.com/turt2live/matrix-media-repo/controllers/upload_controller" + "github.com/turt2live/matrix-media-repo/quota" "github.com/turt2live/matrix-media-repo/util/cleanup" ) @@ -44,6 +45,17 @@ func UploadMedia(r *http.Request, rctx rcontext.RequestContext, user api.UserInf return api.RequestTooSmall() } + inQuota, err := quota.IsUserWithinQuota(rctx, user.UserId) + if err != nil { + io.Copy(ioutil.Discard, r.Body) // Ditch the entire request + rctx.Log.Error("Unexpected error checking quota: " + err.Error()) + return api.InternalServerError("Unexpected Error") + } + if !inQuota { + io.Copy(ioutil.Discard, r.Body) // Ditch the entire request + return api.QuotaExceeded() + } + contentLength := upload_controller.EstimateContentLength(r.ContentLength, r.Header.Get("Content-Length")) media, err := upload_controller.UploadMedia(r.Body, contentLength, contentType, filename, user.UserId, r.Host, rctx) diff --git a/api/responses.go b/api/responses.go index f439191e..d032ce80 100644 --- a/api/responses.go +++ b/api/responses.go @@ -49,3 +49,7 @@ func AuthFailed() *ErrorResponse { func BadRequest(message string) *ErrorResponse { return &ErrorResponse{common.ErrCodeUnknown, message, common.ErrCodeBadRequest} } + +func QuotaExceeded() *ErrorResponse { + return &ErrorResponse{common.ErrCodeForbidden, "Quota Exceeded", common.ErrCodeQuotaExceeded} +} diff --git a/api/webserver/route_handler.go b/api/webserver/route_handler.go index cbf09359..eff47562 100644 --- a/api/webserver/route_handler.go +++ b/api/webserver/route_handler.go @@ -150,6 +150,9 @@ func (h handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case common.ErrCodeMethodNotAllowed: statusCode = http.StatusMethodNotAllowed break + case common.ErrCodeForbidden: + statusCode = http.StatusForbidden + break default: // Treat as unknown (a generic server error) statusCode = http.StatusInternalServerError break diff --git a/common/config/conf_min_shared.go b/common/config/conf_min_shared.go index fb212855..00408488 100644 --- a/common/config/conf_min_shared.go +++ b/common/config/conf_min_shared.go @@ -23,6 +23,10 @@ func NewDefaultMinimumRepoConfig() MinimumRepoConfig { MaxSizeBytes: 104857600, // 100mb MinSizeBytes: 100, ReportedMaxSizeBytes: 0, + Quota: QuotasConfig{ + Enabled: false, + UserQuotas: []QuotaUserConfig{}, + }, }, Identicons: IdenticonsConfig{ Enabled: true, diff --git a/common/config/models_domain.go b/common/config/models_domain.go index 8ede29d9..caaf7286 100644 --- a/common/config/models_domain.go +++ b/common/config/models_domain.go @@ -6,10 +6,21 @@ type ArchivingConfig struct { TargetBytesPerPart int64 `yaml:"targetBytesPerPart"` } +type QuotaUserConfig struct { + Glob string `yaml:"glob"` + MaxBytes int64 `yaml:"maxBytes"` +} + +type QuotasConfig struct { + Enabled bool `yaml:"enabled"` + UserQuotas []QuotaUserConfig `yaml:"users,flow"` +} + type UploadsConfig struct { - MaxSizeBytes int64 `yaml:"maxBytes"` - MinSizeBytes int64 `yaml:"minBytes"` - ReportedMaxSizeBytes int64 `yaml:"reportedMaxBytes"` + MaxSizeBytes int64 `yaml:"maxBytes"` + MinSizeBytes int64 `yaml:"minBytes"` + ReportedMaxSizeBytes int64 `yaml:"reportedMaxBytes"` + Quota QuotasConfig `yaml:"quotas"` } type DatastoreConfig struct { diff --git a/common/errorcodes.go b/common/errorcodes.go index a8cc23ee..8deb6494 100644 --- a/common/errorcodes.go +++ b/common/errorcodes.go @@ -13,3 +13,5 @@ const ErrCodeMethodNotAllowed = "M_METHOD_NOT_ALLOWED" const ErrCodeBadRequest = "M_BAD_REQUEST" const ErrCodeRateLimitExceeded = "M_LIMIT_EXCEEDED" const ErrCodeUnknown = "M_UNKNOWN" +const ErrCodeForbidden = "M_FORBIDDEN" +const ErrCodeQuotaExceeded = "M_QUOTA_EXCEEDED" diff --git a/config.sample.yaml b/config.sample.yaml index 57ec20a4..04b35845 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -195,6 +195,7 @@ archiving: # The file upload settings for the media repository uploads: + # The maximum individual file size a user can upload. maxBytes: 104857600 # 100MB default, 0 to disable # The minimum number of bytes to let people upload. This is recommended to be non-zero to @@ -210,6 +211,24 @@ uploads: # Set this to -1 to indicate that there is no limit. Zero will force the use of maxBytes. #reportedMaxBytes: 104857600 + # Options for limiting how much content a user can upload. Quotas are applied to content + # associated with a user regardless of de-duplication. Quotas which affect remote servers + # or users will not take effect. When a user exceeds their quota they will be unable to + # upload any more media. + quotas: + # Whether or not quotas are enabled/enforced. Note that even when disabled the media repo + # will track how much media a user has uploaded. This is disabled by default. + enabled: false + + # The quota rules that affect users. The first rule to match the uploader will take effect. + # An implied rule which matches all users and has no quota is always last in this list, + # meaning that if no rules are supplied then users will be able to upload anything. Similarly, + # if no rules match a user then the implied rule will match, allowing the user to have no + # quota. The quota will let the user upload to 1 media past their quota, meaning that from + # a statistics perspective the user might exceed their quota however only by a small amount. + users: + - glob: "@*:*" # Affect all users. Use asterisks (*) to match any character. + maxBytes: 53687063712 # 50GB default, 0 to disable # Settings related to downloading files from the media repository downloads: diff --git a/migrations/17_add_user_stats_table_down.sql b/migrations/17_add_user_stats_table_down.sql new file mode 100644 index 00000000..287f0eb0 --- /dev/null +++ b/migrations/17_add_user_stats_table_down.sql @@ -0,0 +1,3 @@ +DROP TRIGGER media_change_for_user; +DELETE FUNCTION track_update_user_media(); +DROP TABLE user_stats; diff --git a/migrations/17_add_user_stats_table_up.sql b/migrations/17_add_user_stats_table_up.sql new file mode 100644 index 00000000..67f1b85c --- /dev/null +++ b/migrations/17_add_user_stats_table_up.sql @@ -0,0 +1,40 @@ +CREATE TABLE IF NOT EXISTS user_stats ( + user_id TEXT PRIMARY KEY NOT NULL, + uploaded_bytes BIGINT NOT NULL +); +CREATE OR REPLACE FUNCTION track_update_user_media() + RETURNS TRIGGER + LANGUAGE PLPGSQL + AS +$$ +BEGIN + IF TG_OP = 'UPDATE' THEN + INSERT INTO user_stats (user_id, uploaded_bytes) VALUES (NEW.user_id, 0) ON CONFLICT (user_id) DO NOTHING; + INSERT INTO user_stats (user_id, uploaded_bytes) VALUES (OLD.user_id, 0) ON CONFLICT (user_id) DO NOTHING; + + IF NEW.user_id <> OLD.user_id THEN + UPDATE user_stats SET uploaded_bytes = user_stats.uploaded_bytes - OLD.size_bytes WHERE user_stats.user_id = OLD.user_id; + UPDATE user_stats SET uploaded_bytes = user_stats.uploaded_bytes + NEW.size_bytes WHERE user_stats.user_id = NEW.user_id; + ELSIF NEW.size_bytes <> OLD.size_bytes THEN + UPDATE user_stats SET uploaded_bytes = user_stats.uploaded_bytes - OLD.size_bytes + NEW.size_bytes WHERE user_stats.user_id = NEW.user_id; + END IF; + RETURN NEW; + ELSIF TG_OP = 'DELETE' THEN + UPDATE user_stats SET uploaded_bytes = user_stats.uploaded_bytes - OLD.size_bytes WHERE user_stats.user_id = OLD.user_id; + RETURN OLD; + ELSIF TG_OP = 'INSERT' THEN + INSERT INTO user_stats (user_id, uploaded_bytes) VALUES (NEW.user_id, NEW.size_bytes) ON CONFLICT (user_id) DO UPDATE SET uploaded_bytes = user_stats.uploaded_bytes + NEW.size_bytes; + RETURN NEW; + END IF; +END; +$$; +DROP TRIGGER IF EXISTS media_change_for_user ON media; +CREATE TRIGGER media_change_for_user AFTER INSERT OR UPDATE OR DELETE ON media FOR EACH ROW EXECUTE PROCEDURE track_update_user_media(); + +-- Populate the new table +DO $$ +BEGIN + IF ((SELECT COUNT(*) FROM user_stats)) = 0 THEN + INSERT INTO user_stats SELECT user_id, SUM(size_bytes) FROM media GROUP BY user_id; + END IF; +END $$; diff --git a/quota/quota.go b/quota/quota.go new file mode 100644 index 00000000..1a8649ff --- /dev/null +++ b/quota/quota.go @@ -0,0 +1,35 @@ +package quota + +import ( + "database/sql" + + "github.com/ryanuber/go-glob" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/storage" +) + +func IsUserWithinQuota(ctx rcontext.RequestContext, userId string) (bool, error) { + if !ctx.Config.Uploads.Quota.Enabled { + return true, nil + } + + db := storage.GetDatabase().GetMetadataStore(ctx) + stat, err := db.GetUserStats(userId) + if err == sql.ErrNoRows { + return true, nil // no stats == within quota + } + if err != nil { + return false, err + } + + for _, q := range ctx.Config.Uploads.Quota.UserQuotas { + if glob.Glob(q.Glob, userId) { + if q.MaxBytes == 0 { + return true, nil // infinite quota + } + return stat.UploadedBytes < q.MaxBytes, nil + } + } + + return true, nil // no rules == no quota +} diff --git a/storage/stores/metadata_store.go b/storage/stores/metadata_store.go index 28c5f712..82a4c22f 100644 --- a/storage/stores/metadata_store.go +++ b/storage/stores/metadata_store.go @@ -31,6 +31,7 @@ const selectReservation = "SELECT origin, media_id, reason FROM reserved_media W const selectMediaLastAccessed = "SELECT m.sha256_hash, m.size_bytes, m.datastore_id, m.location, m.creation_ts, a.last_access_ts FROM media AS m JOIN last_access AS a ON m.sha256_hash = a.sha256_hash WHERE a.last_access_ts < $1;" const insertBlurhash = "INSERT INTO blurhashes (sha256_hash, blurhash) VALUES ($1, $2);" const selectBlurhash = "SELECT blurhash FROM blurhashes WHERE sha256_hash = $1;" +const selectUserStats = "SELECT user_id, uploaded_bytes FROM user_stats WHERE user_id = $1;" type metadataStoreStatements struct { upsertLastAccessed *sql.Stmt @@ -51,6 +52,7 @@ type metadataStoreStatements struct { selectMediaLastAccessed *sql.Stmt insertBlurhash *sql.Stmt selectBlurhash *sql.Stmt + selectUserStats *sql.Stmt } type MetadataStoreFactory struct { @@ -124,6 +126,9 @@ func InitMetadataStore(sqlDb *sql.DB) (*MetadataStoreFactory, error) { if store.stmts.selectBlurhash, err = store.sqlDb.Prepare(selectBlurhash); err != nil { return nil, err } + if store.stmts.selectUserStats, err = store.sqlDb.Prepare(selectUserStats); err != nil { + return nil, err + } return &store, nil } @@ -408,3 +413,17 @@ func (s *MetadataStore) GetBlurhash(sha256Hash string) (string, error) { } return blurhash, nil } + +func (s *MetadataStore) GetUserStats(userId string) (*types.UserStats, error) { + r := s.statements.selectUserStats.QueryRowContext(s.ctx, userId) + + stat := &types.UserStats{} + err := r.Scan( + &stat.UserId, + &stat.UploadedBytes, + ) + if err != nil { + return nil, err + } + return stat, nil +} diff --git a/types/stats.go b/types/stats.go new file mode 100644 index 00000000..01188c9d --- /dev/null +++ b/types/stats.go @@ -0,0 +1,6 @@ +package types + +type UserStats struct { + UserId string + UploadedBytes int64 +}