diff --git a/api/handle_license.go b/api/handle_license.go index 88af8f12..e835bd9e 100644 --- a/api/handle_license.go +++ b/api/handle_license.go @@ -98,3 +98,12 @@ func handleValidateLicense(uc usecases.Usecases) func(c *gin.Context) { c.JSON(http.StatusOK, dto.AdaptLicenseValidationDto(licenseValidation)) } } + +func handleIsSSOEnabled(uc usecases.Usecases) func(c *gin.Context) { + return func(c *gin.Context) { + usecase := uc.NewLicenseUsecase() + c.JSON(http.StatusOK, gin.H{ + "is_sso_enabled": usecase.HasSsoEnabled(), + }) + } +} diff --git a/api/handle_organization.go b/api/handle_organization.go index 4ef40b56..b195415c 100644 --- a/api/handle_organization.go +++ b/api/handle_organization.go @@ -102,3 +102,40 @@ func handleDeleteOrganization(uc usecases.Usecases) func(c *gin.Context) { c.Status(http.StatusNoContent) } } + +func handleGetOrganizationFeatureAccess(uc usecases.Usecases) func(c *gin.Context) { + return func(c *gin.Context) { + ctx := c.Request.Context() + organizationID := c.Param("organization_id") + + usecase := usecasesWithCreds(ctx, uc).NewOrganizationUseCase() + featureAccess, err := usecase.GetOrganizationFeatureAccess(ctx, organizationID) + if presentError(ctx, c, err) { + return + } + c.JSON(http.StatusOK, gin.H{ + "feature_access": dto.AdaptOrganizationFeatureAccessDto(featureAccess), + }) + } +} + +func handlePatchOrganizationFeatureAccess(uc usecases.Usecases) func(c *gin.Context) { + return func(c *gin.Context) { + ctx := c.Request.Context() + organizationID := c.Param("organization_id") + var data dto.UpdateOrganizationFeatureAccessBodyDto + if err := c.ShouldBindJSON(&data); err != nil { + c.Status(http.StatusBadRequest) + return + } + + usecase := usecasesWithCreds(ctx, uc).NewOrganizationUseCase() + err := usecase.UpdateOrganizationFeatureAccess(ctx, + dto.AdaptUpdateOrganizationFeatureAccessInput(data, organizationID)) + if presentError(ctx, c, err) { + return + } + + c.Status(http.StatusNoContent) + } +} diff --git a/api/routes.go b/api/routes.go index d648146a..0447a76a 100644 --- a/api/routes.go +++ b/api/routes.go @@ -27,6 +27,7 @@ func addRoutes(r *gin.Engine, conf Configuration, uc usecases.Usecases, auth Aut r.GET("/liveness", tom, handleLivenessProbe(uc)) r.POST("/token", tom, tokenHandler.GenerateToken) r.GET("/validate-license/*license_key", tom, handleValidateLicense(uc)) + r.GET("/is-sso-available", tom, handleIsSSOEnabled(uc)) router := r.Use(auth.Middleware) @@ -129,6 +130,9 @@ func addRoutes(r *gin.Engine, conf Configuration, uc usecases.Usecases, auth Aut router.GET("/organizations/:organization_id", tom, handleGetOrganization(uc)) router.PATCH("/organizations/:organization_id", tom, handlePatchOrganization(uc)) router.DELETE("/organizations/:organization_id", tom, handleDeleteOrganization(uc)) + router.GET("/organizations/:organization_id/feature_access", tom, handleGetOrganizationFeatureAccess(uc)) + router.PATCH("/organizations/:organization_id/feature_access", tom, + handlePatchOrganizationFeatureAccess(uc)) router.GET("/partners", tom, handleListPartners(uc)) router.POST("/partners", tom, handleCreatePartner(uc)) diff --git a/dto/license_dto.go b/dto/license_dto.go index aa7e7ef8..6d0b61e7 100644 --- a/dto/license_dto.go +++ b/dto/license_dto.go @@ -19,6 +19,7 @@ type LicenseEntitlements struct { Webhooks bool `json:"webhooks"` RuleSnoozes bool `json:"rule_snoozes"` TestRun bool `json:"test_run"` + Sanctions bool `json:"sanctions"` } func AdaptLicenseEntitlements(licenseEntitlements models.LicenseEntitlements) LicenseEntitlements { @@ -31,6 +32,7 @@ func AdaptLicenseEntitlements(licenseEntitlements models.LicenseEntitlements) Li Webhooks: licenseEntitlements.Webhooks, RuleSnoozes: licenseEntitlements.RuleSnoozes, TestRun: licenseEntitlements.TestRun, + Sanctions: licenseEntitlements.Sanctions, } } diff --git a/dto/organization_feature_access_dto.go b/dto/organization_feature_access_dto.go new file mode 100644 index 00000000..38b710c0 --- /dev/null +++ b/dto/organization_feature_access_dto.go @@ -0,0 +1,48 @@ +package dto + +import ( + "github.com/checkmarble/marble-backend/models" + "github.com/checkmarble/marble-backend/utils" +) + +type APIOrganizationFeatureAccess struct { + TestRun string `json:"test_run"` + Workflows string `json:"workflows"` + Webhooks string `json:"webhooks"` + RuleSnoozes string `json:"rule_snoozes"` + Roles string `json:"roles"` + Analytics string `json:"analytics"` + Sanctions string `json:"sanctions"` +} + +func AdaptOrganizationFeatureAccessDto(f models.OrganizationFeatureAccess) APIOrganizationFeatureAccess { + return APIOrganizationFeatureAccess{ + TestRun: f.TestRun.String(), + Workflows: f.Workflows.String(), + Webhooks: f.Webhooks.String(), + RuleSnoozes: f.RuleSnoozes.String(), + Roles: f.Roles.String(), + Analytics: f.Analytics.String(), + Sanctions: f.Sanctions.String(), + } +} + +type UpdateOrganizationFeatureAccessBodyDto struct { + TestRun *string `json:"test_run"` + Sanctions *string `json:"sanctions"` +} + +func AdaptUpdateOrganizationFeatureAccessInput(f UpdateOrganizationFeatureAccessBodyDto, orgId string) models.UpdateOrganizationFeatureAccessInput { + var testRun, sanctions *models.FeatureAccess + if f.TestRun != nil { + testRun = utils.Ptr(models.FeatureAccessFrom(*f.TestRun)) + } + if f.Sanctions != nil { + sanctions = utils.Ptr(models.FeatureAccessFrom(*f.Sanctions)) + } + return models.UpdateOrganizationFeatureAccessInput{ + OrganizationId: orgId, + TestRun: testRun, + Sanctions: sanctions, + } +} diff --git a/models/feature_access.go b/models/feature_access.go new file mode 100644 index 00000000..fa0b3b6c --- /dev/null +++ b/models/feature_access.go @@ -0,0 +1,38 @@ +package models + +type FeatureAccess int + +const ( + Restricted FeatureAccess = iota + Allowed + Test + UnknownFeatureAccess +) + +var ValidFeaturesAccess = []FeatureAccess{Allowed, Restricted, Test} + +// Provide a string value for each outcome +func (f FeatureAccess) String() string { + switch f { + case Allowed: + return "allowed" + case Restricted: + return "restricted" + case Test: + return "test" + } + return "unknown" +} + +// Provide an Outcome from a string value +func FeatureAccessFrom(s string) FeatureAccess { + switch s { + case "allowed": + return Allowed + case "restricted": + return Restricted + case "test": + return Test + } + return UnknownFeatureAccess +} diff --git a/models/license.go b/models/license.go index 38c10ad5..506b4819 100644 --- a/models/license.go +++ b/models/license.go @@ -59,6 +59,7 @@ type LicenseEntitlements struct { Webhooks bool RuleSnoozes bool TestRun bool + Sanctions bool } type LicenseValidation struct { @@ -78,6 +79,7 @@ func NewFullLicense() LicenseValidation { Webhooks: true, RuleSnoozes: true, TestRun: true, + Sanctions: true, }, } } diff --git a/models/organization_feature_access.go b/models/organization_feature_access.go new file mode 100644 index 00000000..723ea4ea --- /dev/null +++ b/models/organization_feature_access.go @@ -0,0 +1,69 @@ +package models + +import "time" + +type OrganizationFeatureAccess struct { + Id string + OrganizationId string + TestRun FeatureAccess + Workflows FeatureAccess + Webhooks FeatureAccess + RuleSnoozes FeatureAccess + Roles FeatureAccess + Analytics FeatureAccess + Sanctions FeatureAccess + CreatedAt time.Time + UpdatedAt time.Time +} + +type DbStoredOrganizationFeatureAccess struct { + Id string + OrganizationId string + TestRun FeatureAccess + Sanctions FeatureAccess + CreatedAt time.Time + UpdatedAt time.Time +} + +type UpdateOrganizationFeatureAccessInput struct { + OrganizationId string + TestRun *FeatureAccess + Sanctions *FeatureAccess +} + +func (f DbStoredOrganizationFeatureAccess) MergeWithLicenseEntitlement(l *LicenseEntitlements) OrganizationFeatureAccess { + o := OrganizationFeatureAccess{ + Id: f.Id, + OrganizationId: f.OrganizationId, + TestRun: f.TestRun, + Sanctions: f.Sanctions, + CreatedAt: f.CreatedAt, + UpdatedAt: f.UpdatedAt, + } + // First, set the feature accesses to "allowed" if the license allows it + if l.Analytics { + o.Analytics = Allowed + } + if l.Webhooks { + o.Webhooks = Allowed + } + if l.Workflows { + o.Workflows = Allowed + } + if l.RuleSnoozes { + o.RuleSnoozes = Allowed + } + if l.UserRoles { + o.Roles = Allowed + } + + // remove the feature accesses that are not allowed by the license + if !l.TestRun { + o.TestRun = Restricted + } + if !l.Sanctions { + o.Sanctions = Restricted + } + + return o +} diff --git a/repositories/dbmodels/db_license.go b/repositories/dbmodels/db_license.go index 73a898de..9d9ff825 100644 --- a/repositories/dbmodels/db_license.go +++ b/repositories/dbmodels/db_license.go @@ -25,6 +25,7 @@ type DBLicense struct { Webhooks bool `db:"webhooks"` RuleSnoozes bool `db:"rule_snoozes"` TestRun bool `db:"test_run"` + Sanctions bool `db:"sanctions"` } const TABLE_LICENSES = "licenses" @@ -49,6 +50,7 @@ func AdaptLicense(db DBLicense) (models.License, error) { Webhooks: db.Webhooks, RuleSnoozes: db.RuleSnoozes, TestRun: db.TestRun, + Sanctions: db.Sanctions, }, }, nil } diff --git a/repositories/dbmodels/db_organization_feature_access.go b/repositories/dbmodels/db_organization_feature_access.go new file mode 100644 index 00000000..4535ad7e --- /dev/null +++ b/repositories/dbmodels/db_organization_feature_access.go @@ -0,0 +1,32 @@ +package dbmodels + +import ( + "time" + + "github.com/checkmarble/marble-backend/models" + "github.com/checkmarble/marble-backend/utils" +) + +const TABLE_ORGANIZATION_FEATURE_ACCESS = "organization_feature_access" + +var SelectOrganizationFeatureAccessColumn = utils.ColumnList[DBOrganizationFeatureAccess]() + +type DBOrganizationFeatureAccess struct { + Id string `db:"id"` + OrganizationId string `db:"org_id"` + TestRun string `db:"test_run"` + Sanctions string `db:"sanctions"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +func AdaptOrganizationFeatureAccess(db DBOrganizationFeatureAccess) (models.DbStoredOrganizationFeatureAccess, error) { + return models.DbStoredOrganizationFeatureAccess{ + Id: db.Id, + OrganizationId: db.OrganizationId, + TestRun: models.FeatureAccessFrom(db.TestRun), + Sanctions: models.FeatureAccessFrom(db.Sanctions), + CreatedAt: db.CreatedAt, + UpdatedAt: db.UpdatedAt, + }, nil +} diff --git a/repositories/migrations/20250108102844_add_organization_feature_access_table.sql b/repositories/migrations/20250108102844_add_organization_feature_access_table.sql new file mode 100644 index 00000000..9157ed16 --- /dev/null +++ b/repositories/migrations/20250108102844_add_organization_feature_access_table.sql @@ -0,0 +1,36 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE + organization_feature_access ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4 (), + org_id UUID NOT NULL, + test_run VARCHAR NOT NULL DEFAULT 'allowed', + sanctions VARCHAR NOT NULL DEFAULT 'allowed', + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT fk_org FOREIGN KEY (org_id) REFERENCES organizations (id) ON DELETE CASCADE + ); + +INSERT INTO + organization_feature_access (org_id) +SELECT + id +FROM + organizations; + +CREATE UNIQUE INDEX unique_organization_feature_access ON organization_feature_access (org_id); + +ALTER TABLE licenses +ADD COLUMN sanctions BOOLEAN NOT NULL DEFAULT FALSE; + +-- +goose StatementEnd +-- +goose Down +-- +goose StatementBegin +DROP INDEX unique_organization_feature_access; + +DROP TABLE organization_feature_access; + +ALTER TABLE licenses +DROP COLUMN sanctions; + +-- +goose StatementEnd \ No newline at end of file diff --git a/repositories/organization_repository.go b/repositories/organization_repository.go index 1da25676..4b90b4c6 100644 --- a/repositories/organization_repository.go +++ b/repositories/organization_repository.go @@ -3,6 +3,7 @@ package repositories import ( "context" + "github.com/Masterminds/squirrel" "github.com/checkmarble/marble-backend/models" "github.com/checkmarble/marble-backend/repositories/dbmodels" "github.com/checkmarble/marble-backend/utils" @@ -16,6 +17,14 @@ type OrganizationRepository interface { UpdateOrganization(ctx context.Context, exec Executor, updateOrganization models.UpdateOrganizationInput) error DeleteOrganization(ctx context.Context, exec Executor, organizationId string) error DeleteOrganizationDecisionRulesAsync(ctx context.Context, exec Executor, organizationId string) + GetOrganizationFeatureAccess(ctx context.Context, exec Executor, organizationId string) ( + models.DbStoredOrganizationFeatureAccess, error, + ) + UpdateOrganizationFeatureAccess( + ctx context.Context, + exec Executor, + updateFeatureAccess models.UpdateOrganizationFeatureAccessInput, + ) error } type OrganizationRepositoryPostgresql struct{} @@ -76,7 +85,16 @@ func (repo *OrganizationRepositoryPostgresql) CreateOrganization( name, ), ) - return err + if err != nil { + return err + } + + newErr := ExecBuilder(ctx, exec, NewQueryBuilder(). + Insert(dbmodels.TABLE_ORGANIZATION_FEATURE_ACCESS). + Columns("org_id"). + Values(newOrganizationId)) + + return newErr } func (repo *OrganizationRepositoryPostgresql) UpdateOrganization(ctx context.Context, exec Executor, updateOrganization models.UpdateOrganizationInput) error { @@ -124,3 +142,54 @@ func (repo *OrganizationRepositoryPostgresql) DeleteOrganizationDecisionRulesAsy } }() } + +func (repo *OrganizationRepositoryPostgresql) GetOrganizationFeatureAccess(ctx context.Context, exec Executor, + organizationId string, +) (models.DbStoredOrganizationFeatureAccess, error) { + if err := validateMarbleDbExecutor(exec); err != nil { + return models.DbStoredOrganizationFeatureAccess{}, err + } + + return SqlToModel( + ctx, + exec, + NewQueryBuilder(). + Select(dbmodels.SelectOrganizationFeatureAccessColumn...). + From(dbmodels.TABLE_ORGANIZATION_FEATURE_ACCESS). + Where(squirrel.Eq{"org_id": organizationId}), + dbmodels.AdaptOrganizationFeatureAccess, + ) +} + +func (repo *OrganizationRepositoryPostgresql) UpdateOrganizationFeatureAccess( + ctx context.Context, + exec Executor, + updateFeatureAccess models.UpdateOrganizationFeatureAccessInput, +) error { + if err := validateMarbleDbExecutor(exec); err != nil { + return err + } + + query := NewQueryBuilder(). + Update(dbmodels.TABLE_ORGANIZATION_FEATURE_ACCESS). + Where(squirrel.Eq{"org_id": updateFeatureAccess.OrganizationId}) + + nbUpdated := 0 + if updateFeatureAccess.TestRun != nil { + query = query.Set("test_run", *updateFeatureAccess.TestRun) + nbUpdated++ + } + if updateFeatureAccess.Sanctions != nil { + query = query.Set("sanctions", *updateFeatureAccess.Sanctions) + nbUpdated++ + } + + if nbUpdated == 0 { + return nil + } + + query.Set("updated_at", squirrel.Expr("NOW()")) + + err := ExecBuilder(ctx, exec, query) + return err +} diff --git a/usecases/license_usecase.go b/usecases/license_usecase.go index f39f0efa..8f5897ee 100644 --- a/usecases/license_usecase.go +++ b/usecases/license_usecase.go @@ -125,6 +125,19 @@ type publicLicenseRepository interface { type PublicLicenseUseCase struct { executorFactory executor_factory.ExecutorFactory licenseRepository publicLicenseRepository + license models.LicenseValidation +} + +func NewPublicLicenseUsecase( + executorFactory executor_factory.ExecutorFactory, + publicLicenseRepository publicLicenseRepository, + license models.LicenseValidation, +) PublicLicenseUseCase { + return PublicLicenseUseCase{ + executorFactory: executorFactory, + licenseRepository: publicLicenseRepository, + license: license, + } } func (usecase *PublicLicenseUseCase) ValidateLicense(ctx context.Context, licenseKey string) (models.LicenseValidation, error) { @@ -155,3 +168,7 @@ func (usecase *PublicLicenseUseCase) ValidateLicense(ctx context.Context, licens LicenseEntitlements: license.LicenseEntitlements, }, nil } + +func (usecase *PublicLicenseUseCase) HasSsoEnabled() bool { + return usecase.license.Sso +} diff --git a/usecases/organization_usecase.go b/usecases/organization_usecase.go index bbc4c7b6..7700ead9 100644 --- a/usecases/organization_usecase.go +++ b/usecases/organization_usecase.go @@ -21,6 +21,31 @@ type OrganizationUseCase struct { organizationCreator organization.OrganizationCreator organizationSchemaRepository repositories.OrganizationSchemaRepository executorFactory executor_factory.ExecutorFactory + license models.LicenseValidation +} + +func NewOrganizationUseCase( + enforceSecurity security.EnforceSecurityOrganization, + transactionFactory executor_factory.TransactionFactory, + organizationRepository repositories.OrganizationRepository, + datamodelRepository repositories.DataModelRepository, + userRepository repositories.UserRepository, + organizationCreator organization.OrganizationCreator, + organizationSchemaRepository repositories.OrganizationSchemaRepository, + executorFactory executor_factory.ExecutorFactory, + license models.LicenseValidation, +) OrganizationUseCase { + return OrganizationUseCase{ + enforceSecurity: enforceSecurity, + transactionFactory: transactionFactory, + organizationRepository: organizationRepository, + datamodelRepository: datamodelRepository, + userRepository: userRepository, + organizationCreator: organizationCreator, + organizationSchemaRepository: organizationSchemaRepository, + executorFactory: executorFactory, + license: license, + } } func (usecase *OrganizationUseCase) GetOrganizations(ctx context.Context) ([]models.Organization, error) { @@ -109,3 +134,31 @@ func (usecase *OrganizationUseCase) DeleteOrganization(ctx context.Context, orga ) return nil } + +func (usecase *OrganizationUseCase) GetOrganizationFeatureAccess(ctx context.Context, + organizationId string, +) (models.OrganizationFeatureAccess, error) { + if err := usecase.enforceSecurity.ReadOrganization(organizationId); err != nil { + return models.OrganizationFeatureAccess{}, err + } + + dbStoredFeatureAccess, err := usecase.organizationRepository.GetOrganizationFeatureAccess(ctx, + usecase.executorFactory.NewExecutor(), organizationId) + if err != nil { + return models.OrganizationFeatureAccess{}, err + } + + return dbStoredFeatureAccess.MergeWithLicenseEntitlement( + &usecase.license.LicenseEntitlements), nil +} + +func (usecase *OrganizationUseCase) UpdateOrganizationFeatureAccess( + ctx context.Context, + featureAccess models.UpdateOrganizationFeatureAccessInput, +) error { + if err := usecase.enforceSecurity.CreateOrganization(); err != nil { + return err + } + return usecase.organizationRepository.UpdateOrganizationFeatureAccess(ctx, + usecase.executorFactory.NewExecutor(), featureAccess) +} diff --git a/usecases/usecases.go b/usecases/usecases.go index 39e52a18..8bd50755 100644 --- a/usecases/usecases.go +++ b/usecases/usecases.go @@ -242,10 +242,11 @@ func (usecase *Usecases) NewScenarioFetcher() scenarios.ScenarioFetcher { } func (usecases *Usecases) NewLicenseUsecase() PublicLicenseUseCase { - return PublicLicenseUseCase{ - executorFactory: usecases.NewExecutorFactory(), - licenseRepository: &usecases.Repositories.MarbleDbRepository, - } + return NewPublicLicenseUsecase( + usecases.NewExecutorFactory(), + &usecases.Repositories.MarbleDbRepository, + usecases.license, + ) } func (usecases *Usecases) NewTaskQueueWorker(riverClient *river.Client[pgx.Tx]) *TaskQueueWorker { diff --git a/usecases/usecases_with_creds.go b/usecases/usecases_with_creds.go index a23979a8..44950fb8 100644 --- a/usecases/usecases_with_creds.go +++ b/usecases/usecases_with_creds.go @@ -196,16 +196,17 @@ func (usecases *UsecasesWithCreds) NewClientDbIndexEditor() clientDbIndexEditor } func (usecases *UsecasesWithCreds) NewOrganizationUseCase() OrganizationUseCase { - return OrganizationUseCase{ - enforceSecurity: usecases.NewEnforceOrganizationSecurity(), - executorFactory: usecases.NewExecutorFactory(), - transactionFactory: usecases.NewTransactionFactory(), - organizationRepository: usecases.Repositories.OrganizationRepository, - datamodelRepository: usecases.Repositories.DataModelRepository, - userRepository: usecases.Repositories.UserRepository, - organizationCreator: usecases.NewOrganizationCreator(), - organizationSchemaRepository: usecases.Repositories.OrganizationSchemaRepository, - } + return NewOrganizationUseCase( + usecases.NewEnforceOrganizationSecurity(), + usecases.NewTransactionFactory(), + usecases.Repositories.OrganizationRepository, + usecases.Repositories.DataModelRepository, + usecases.Repositories.UserRepository, + usecases.NewOrganizationCreator(), + usecases.Repositories.OrganizationSchemaRepository, + usecases.NewExecutorFactory(), + usecases.Usecases.license, + ) } func (usecases *UsecasesWithCreds) NewDataModelUseCase() DataModelUseCase {