Skip to content

Commit

Permalink
Fix studio name uniqueness validation (#4454)
Browse files Browse the repository at this point in the history
  • Loading branch information
WithoutPants authored Jan 14, 2024
1 parent 08b7358 commit 5cf28cf
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 8 deletions.
8 changes: 1 addition & 7 deletions internal/api/resolver_mutation_studio.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,10 @@ func (r *mutationResolver) StudioCreate(ctx context.Context, input models.Studio
if err := r.withTxn(ctx, func(ctx context.Context) error {
qb := r.repository.Studio

if err := studio.EnsureStudioNameUnique(ctx, 0, newStudio.Name, qb); err != nil {
if err := studio.ValidateCreate(ctx, newStudio, qb); err != nil {
return err
}

if len(input.Aliases) > 0 {
if err := studio.EnsureAliasesUnique(ctx, 0, input.Aliases, qb); err != nil {
return err
}
}

err = qb.Create(ctx, &newStudio)
if err != nil {
return err
Expand Down
4 changes: 4 additions & 0 deletions internal/manager/task_stash_box_tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ func (t *StashBoxBatchTagTask) processMatchedStudio(ctx context.Context, s *mode
err = r.WithTxn(ctx, func(ctx context.Context) error {
qb := r.Studio

if err := studio.ValidateCreate(ctx, *newStudio, qb); err != nil {
return err
}

if err := qb.Create(ctx, newStudio); err != nil {
return err
}
Expand Down
29 changes: 28 additions & 1 deletion pkg/studio/update.go → pkg/studio/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
)

var (
ErrNameMissing = errors.New("studio name must not be blank")
ErrStudioOwnAncestor = errors.New("studio cannot be an ancestor of itself")
)

Expand Down Expand Up @@ -70,6 +71,32 @@ func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb model
return nil
}

func ValidateCreate(ctx context.Context, studio models.Studio, qb models.StudioQueryer) error {
if err := validateName(ctx, 0, studio.Name, qb); err != nil {
return err
}

if studio.Aliases.Loaded() && len(studio.Aliases.List()) > 0 {
if err := EnsureAliasesUnique(ctx, 0, studio.Aliases.List(), qb); err != nil {
return err
}
}

return nil
}

func validateName(ctx context.Context, studioID int, name string, qb models.StudioQueryer) error {
if name == "" {
return ErrNameMissing
}

if err := EnsureStudioNameUnique(ctx, studioID, name, qb); err != nil {
return err
}

return nil
}

type ValidateModifyReader interface {
models.StudioGetter
models.StudioQueryer
Expand Down Expand Up @@ -110,7 +137,7 @@ func ValidateModify(ctx context.Context, s models.StudioPartial, qb ValidateModi
}

if s.Name.Set && s.Name.Value != existing.Name {
if err := EnsureStudioNameUnique(ctx, 0, s.Name.Value, qb); err != nil {
if err := validateName(ctx, s.ID, s.Name.Value, qb); err != nil {
return err
}
}
Expand Down
104 changes: 104 additions & 0 deletions pkg/studio/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package studio

import (
"testing"

"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/models/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func nameFilter(n string) *models.StudioFilterType {
return &models.StudioFilterType{
Name: &models.StringCriterionInput{
Value: n,
Modifier: models.CriterionModifierEquals,
},
}
}

func TestValidateName(t *testing.T) {
db := mocks.NewDatabase()

const (
name1 = "name 1"
newName = "new name"
)

existing1 := models.Studio{
ID: 1,
Name: name1,
}

pp := 1
findFilter := &models.FindFilterType{
PerPage: &pp,
}

db.Studio.On("Query", testCtx, nameFilter(name1), findFilter).Return([]*models.Studio{&existing1}, 1, nil)
db.Studio.On("Query", testCtx, mock.Anything, findFilter).Return(nil, 0, nil)

tests := []struct {
tName string
name string
want error
}{
{"missing name", "", ErrNameMissing},
{"new name", newName, nil},
{"existing name", name1, &NameExistsError{name1}},
}

for _, tt := range tests {
t.Run(tt.tName, func(t *testing.T) {
got := validateName(testCtx, 0, tt.name, db.Studio)
assert.Equal(t, tt.want, got)
})
}
}

func TestValidateUpdateName(t *testing.T) {
db := mocks.NewDatabase()

const (
name1 = "name 1"
name2 = "name 2"
newName = "new name"
)

existing1 := models.Studio{
ID: 1,
Name: name1,
}
existing2 := models.Studio{
ID: 2,
Name: name2,
}

pp := 1
findFilter := &models.FindFilterType{
PerPage: &pp,
}

db.Studio.On("Query", testCtx, nameFilter(name1), findFilter).Return([]*models.Studio{&existing1}, 1, nil)
db.Studio.On("Query", testCtx, nameFilter(name2), findFilter).Return([]*models.Studio{&existing2}, 2, nil)
db.Studio.On("Query", testCtx, mock.Anything, findFilter).Return(nil, 0, nil)

tests := []struct {
tName string
studio models.Studio
name string
want error
}{
{"missing name", existing1, "", ErrNameMissing},
{"same name", existing2, name2, nil},
{"new name", existing1, newName, nil},
}

for _, tt := range tests {
t.Run(tt.tName, func(t *testing.T) {
got := validateName(testCtx, tt.studio.ID, tt.name, db.Studio)
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit 5cf28cf

Please sign in to comment.