From 318fd84b9d95d5547c3471289f4ca19e1c319cc0 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Thu, 30 Mar 2023 19:48:18 +0200 Subject: [PATCH 01/12] Add type-safety in filters --- .github/workflows/test.yml | 6 +- README.md | 43 +++++++++++++- filter.go | 7 ++- filter_test.go | 4 +- join.go | 4 +- operator.go | 72 ++++++++++++++++-------- operator_test.go | 71 +++++++++++++++--------- search.go | 6 +- search_test.go | 14 ++--- settings_test.go | 4 +- util.go | 111 +++++++++++++++++++++++++++++++++++++ util_test.go | 62 +++++++++++++++++++++ 12 files changed, 335 insertions(+), 69 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3fa06b6..1665567 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: [1.17, 1.18, 1.19] + go: ["1.17", "1.18", "1.19", "1.20"] steps: - uses: actions/checkout@v3 - uses: actions/setup-go@v3 @@ -23,7 +23,7 @@ jobs: - name: Run tests run: | go test -v -race -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... ./... - - if: ${{ matrix.go == 1.19 }} + - if: ${{ matrix.go == 1.20 }} uses: shogo82148/actions-goveralls@v1 with: path-to-profile: coverage.txt @@ -36,5 +36,5 @@ jobs: - name: Run lint uses: golangci/golangci-lint-action@v3 with: - version: v1.50 + version: v1.52 args: --timeout 5m diff --git a/README.md b/README.md index 85940bd..6bd1a70 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,7 @@ type MyModelWithStatus struct{ - Inputs are escaped to prevent SQL injections. - Fields are pre-processed and clients cannot request fields that don't exist. This prevents database errors. If a non-existing field is required, it is simply ignored. The same goes for sorts and joins. It is not possible to request a relation that doesn't exist. +- Type-safety: in the same field pre-processing, the broad type of the field is checked against the database type (based on the model definition). This prevents database errors if the input cannot be converted to the column's type. - Foreign keys are always selected in joins to ensure associations can be assigned to parent model. - **Be careful** with bidirectional relations (for example an article is written by a user, and a user can have many articles). If you enabled both your models to preload these relations, the client can request them with an infinite depth (`Articles.User.Articles.User...`). To prevent this, it is advised to use **the relation blacklist** or **IsFinal** on the deepest requestable models. See the settings section for more details. @@ -251,6 +252,27 @@ type MyModelWithStatus struct{ - Don't use `gorm.Model` and add the necessary fields manually. You get better control over json struct tags this way. - Use pointers for nullable relations and nullable fields that implement `sql.Scanner` (such as `null.Time`). +### Filter type + +For non-primitive types (such as `*null.Time`), you should always use the `filter_type` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. + +Available broad types are: +- `text` +- `bool` +- `int` +- `uint` +- `float` +- `time` + +**Example** +```go +type MyModel struct{ + ID uint + // ... + StartDate null.Time `filter_type:"time"` +} +``` + ### Static conditions If you want to add static conditions (not automatically defined by the library), it is advised to group them like so: @@ -279,10 +301,25 @@ import ( // ... filter.Operators["$cont"] = &filter.Operator{ - Function: func(tx *gorm.DB, filter *filter.Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB { + if dataType != schema.String { + return tx + } query := column + " LIKE ?" - value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%" - return filter.Where(tx, query, value) + value := "%" + sqlutil.EscapeLike(f.Args[0]) + "%" + return f.Where(tx, query, value) + }, + RequiredArguments: 1, +} + +filter.Operators["$eq"] = &filter.Operator{ + Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB { + arg, ok := filter.ConvertToSafeType(f.Args[0], dataType) + if !ok { + return tx + } + query := fmt.Sprintf("%s = ?", column, op) + return f.Where(tx, query, arg) }, RequiredArguments: 1, } diff --git a/filter.go b/filter.go index 8af4eab..b697e29 100644 --- a/filter.go +++ b/filter.go @@ -46,7 +46,12 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) * } else { fieldExpr = table + "." + tx.Statement.Quote(field.DBName) } - return f.Operator.Function(tx, f, fieldExpr, field.DataType) + + dataType := getDataType(field) + if dataType == DataTypeUnsupported { + return tx + } + return f.Operator.Function(tx, f, fieldExpr, dataType) } return joinScope, conditionScope diff --git a/filter_test.go b/filter_test.go index 6037602..2cf3103 100644 --- a/filter_test.go +++ b/filter_test.go @@ -53,7 +53,7 @@ func TestFilterScope(t *testing.T) { schema := &schema.Schema{ DBNames: []string{"name"}, FieldsByDBName: map[string]*schema.Field{ - "name": {Name: "Name", DBName: "name"}, + "name": {Name: "Name", DBName: "name", DataType: schema.String}, }, Table: "test_scope_models", } @@ -346,7 +346,7 @@ func TestFilterScopeWithJoinDontDuplicate(t *testing.T) { Expression: clause.Where{ Exprs: []clause.Expression{ clause.Expr{SQL: "`Relation`.`name` = ?", Vars: []interface{}{"val1"}}, - clause.Expr{SQL: "`Relation`.`id` > ?", Vars: []interface{}{"0"}}, + clause.Expr{SQL: "`Relation`.`id` > ?", Vars: []interface{}{uint64(0)}}, }, }, }, diff --git a/join.go b/join.go index 17f14f8..5b60dfc 100644 --- a/join.go +++ b/join.go @@ -157,7 +157,7 @@ func join(tx *gorm.DB, joinName string, sch *schema.Schema) *gorm.DB { Table: clause.Table{Name: sch.Table, Alias: relation.Name}, ON: clause.Where{Exprs: exprs}, } - if !joinExists(tx.Statement, j) && !findStatementJoin(tx.Statement, relation, &j) { + if !joinExists(tx.Statement, j) && !findStatementJoin(tx.Statement, &j) { joins = append(joins, j) } } @@ -201,7 +201,7 @@ func joinExists(stmt *gorm.Statement, join clause.Join) bool { // Removes this information from the join afterwards to avoid Gorm reprocessing it. // This is used to avoid duplicate joins that produce ambiguous column names and to // support computed columns. -func findStatementJoin(stmt *gorm.Statement, relation *schema.Relationship, join *clause.Join) bool { +func findStatementJoin(stmt *gorm.Statement, join *clause.Join) bool { for _, j := range stmt.Joins { if j.Name == join.Table.Alias { return true diff --git a/operator.go b/operator.go index b5ead2d..8b000cf 100644 --- a/operator.go +++ b/operator.go @@ -4,19 +4,21 @@ import ( "fmt" "gorm.io/gorm" - "gorm.io/gorm/schema" "goyave.dev/goyave/v4/util/sqlutil" ) // Operator used by filters to build the SQL query. // The operator function modifies the GORM statement (most of the time by adding // a WHERE condition) then returns the modified statement. +// // Operators may need arguments (e.g. "$eq", equals needs a value to compare the field to); // RequiredArguments define the minimum number of arguments a client must send in order to // use this operator in a filter. RequiredArguments is checked during Filter parsing. -// Operators may return the given tx without change if they don't support the given dataType. +// +// Operators may return the given tx without change if they don't support the given dataType or +// add a condition that will always be false. type Operator struct { - Function func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB + Function func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB RequiredArguments uint8 } @@ -30,7 +32,10 @@ var ( "$gte": {Function: basicComparison(">="), RequiredArguments: 1}, "$lte": {Function: basicComparison("<="), RequiredArguments: 1}, "$starts": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType != DataTypeText { + return tx.Where("FALSE") + } query := column + " LIKE ?" value := sqlutil.EscapeLike(filter.Args[0]) + "%" return filter.Where(tx, query, value) @@ -38,7 +43,10 @@ var ( RequiredArguments: 1, }, "$ends": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType != DataTypeText { + return tx.Where("FALSE") + } query := column + " LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) return filter.Where(tx, query, value) @@ -46,7 +54,10 @@ var ( RequiredArguments: 1, }, "$cont": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType != DataTypeText { + return tx.Where("FALSE") + } query := column + " LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%" return filter.Where(tx, query, value) @@ -54,7 +65,10 @@ var ( RequiredArguments: 1, }, "$excl": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType != DataTypeText { + return tx.Where("FALSE") + } query := column + " NOT LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%" return filter.Where(tx, query, value) @@ -64,55 +78,67 @@ var ( "$in": {Function: multiComparison("IN"), RequiredArguments: 1}, "$notin": {Function: multiComparison("NOT IN"), RequiredArguments: 1}, "$isnull": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return filter.Where(tx, column+" IS NULL") }, RequiredArguments: 0, }, "$istrue": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { - if dataType != schema.Bool { - return tx + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType != DataTypeBool { + return tx.Where("FALSE") } return filter.Where(tx, column+" IS TRUE") }, RequiredArguments: 0, }, "$isfalse": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { - if dataType != schema.Bool { - return tx + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType != DataTypeBool { + return tx.Where("FALSE") } return filter.Where(tx, column+" IS FALSE") }, RequiredArguments: 0, }, "$notnull": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return filter.Where(tx, column+" IS NOT NULL") }, RequiredArguments: 0, }, "$between": { - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + args, ok := ConvertArgsToSafeType(filter.Args[:2], dataType) + if !ok { + return tx.Where("FALSE") + } query := column + " BETWEEN ? AND ?" - return filter.Where(tx, query, filter.Args[0], filter.Args[1]) + return filter.Where(tx, query, args...) }, RequiredArguments: 2, }, } ) -func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { - return func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { +func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + arg, ok := ConvertToSafeType(filter.Args[0], dataType) + if !ok { + return tx.Where("FALSE") + } query := fmt.Sprintf("%s %s ?", column, op) - return filter.Where(tx, query, filter.Args[0]) + return filter.Where(tx, query, arg) } } -func multiComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { - return func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { +func multiComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + args, ok := ConvertArgsToSafeType(filter.Args, dataType) + if !ok { + return tx.Where("FALSE") + } query := fmt.Sprintf("%s %s ?", column, op) - return filter.Where(tx, query, filter.Args) + return filter.Where(tx, query, args) } } diff --git a/operator_test.go b/operator_test.go index 5a5bca6..8f2d709 100644 --- a/operator_test.go +++ b/operator_test.go @@ -5,12 +5,11 @@ import ( "github.com/stretchr/testify/assert" "gorm.io/gorm/clause" - "gorm.io/gorm/schema" ) func TestEquals(t *testing.T) { db := openDryRunDB(t) - db = Operators["$eq"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", schema.String) + db = Operators["$eq"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -27,7 +26,7 @@ func TestEquals(t *testing.T) { func TestNotEquals(t *testing.T) { db := openDryRunDB(t) - db = Operators["$ne"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", schema.String) + db = Operators["$ne"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -44,7 +43,7 @@ func TestNotEquals(t *testing.T) { func TestGreaterThan(t *testing.T) { db := openDryRunDB(t) - db = Operators["$gt"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", schema.String) + db = Operators["$gt"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -61,7 +60,7 @@ func TestGreaterThan(t *testing.T) { func TestLowerThan(t *testing.T) { db := openDryRunDB(t) - db = Operators["$lt"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", schema.String) + db = Operators["$lt"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -78,7 +77,7 @@ func TestLowerThan(t *testing.T) { func TestGreaterThanEqual(t *testing.T) { db := openDryRunDB(t) - db = Operators["$gte"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", schema.String) + db = Operators["$gte"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -95,7 +94,7 @@ func TestGreaterThanEqual(t *testing.T) { func TestLowerThanEqual(t *testing.T) { db := openDryRunDB(t) - db = Operators["$lte"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", schema.String) + db = Operators["$lte"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -112,7 +111,7 @@ func TestLowerThanEqual(t *testing.T) { func TestStarts(t *testing.T) { db := openDryRunDB(t) - db = Operators["$starts"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", schema.String) + db = Operators["$starts"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -129,7 +128,7 @@ func TestStarts(t *testing.T) { func TestEnds(t *testing.T) { db := openDryRunDB(t) - db = Operators["$ends"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", schema.String) + db = Operators["$ends"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -146,7 +145,7 @@ func TestEnds(t *testing.T) { func TestContains(t *testing.T) { db := openDryRunDB(t) - db = Operators["$cont"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", schema.String) + db = Operators["$cont"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -163,7 +162,7 @@ func TestContains(t *testing.T) { func TestNotContains(t *testing.T) { db := openDryRunDB(t) - db = Operators["$excl"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", schema.String) + db = Operators["$excl"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -180,14 +179,14 @@ func TestNotContains(t *testing.T) { func TestIn(t *testing.T) { db := openDryRunDB(t) - db = Operators["$in"].Function(db, &Filter{Field: "name", Args: []string{"val1", "val2"}}, "`test_models`.`name`", schema.String) + db = Operators["$in"].Function(db, &Filter{Field: "name", Args: []string{"val1", "val2"}}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { Name: "WHERE", Expression: clause.Where{ Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` IN ?", Vars: []interface{}{[]string{"val1", "val2"}}}, + clause.Expr{SQL: "`test_models`.`name` IN ?", Vars: []interface{}{[]interface{}{"val1", "val2"}}}, }, }, }, @@ -197,14 +196,14 @@ func TestIn(t *testing.T) { func TestNotIn(t *testing.T) { db := openDryRunDB(t) - db = Operators["$notin"].Function(db, &Filter{Field: "name", Args: []string{"val1", "val2"}}, "`test_models`.`name`", schema.String) + db = Operators["$notin"].Function(db, &Filter{Field: "name", Args: []string{"val1", "val2"}}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { Name: "WHERE", Expression: clause.Where{ Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` NOT IN ?", Vars: []interface{}{[]string{"val1", "val2"}}}, + clause.Expr{SQL: "`test_models`.`name` NOT IN ?", Vars: []interface{}{[]interface{}{"val1", "val2"}}}, }, }, }, @@ -214,7 +213,7 @@ func TestNotIn(t *testing.T) { func TestIsNull(t *testing.T) { db := openDryRunDB(t) - db = Operators["$isnull"].Function(db, &Filter{Field: "name"}, "`test_models`.`name`", schema.String) + db = Operators["$isnull"].Function(db, &Filter{Field: "name"}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -231,7 +230,7 @@ func TestIsNull(t *testing.T) { func TestNotNull(t *testing.T) { db := openDryRunDB(t) - db = Operators["$notnull"].Function(db, &Filter{Field: "name"}, "`test_models`.`name`", schema.String) + db = Operators["$notnull"].Function(db, &Filter{Field: "name"}, "`test_models`.`name`", DataTypeText) expected := map[string]clause.Clause{ "WHERE": { @@ -248,14 +247,14 @@ func TestNotNull(t *testing.T) { func TestBetween(t *testing.T) { db := openDryRunDB(t) - db = Operators["$between"].Function(db, &Filter{Field: "age", Args: []string{"18", "25"}}, "`test_models`.`age`", schema.Uint) + db = Operators["$between"].Function(db, &Filter{Field: "age", Args: []string{"18", "25"}}, "`test_models`.`age`", DataTypeUint) expected := map[string]clause.Clause{ "WHERE": { Name: "WHERE", Expression: clause.Where{ Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`age` BETWEEN ? AND ?", Vars: []interface{}{"18", "25"}}, + clause.Expr{SQL: "`test_models`.`age` BETWEEN ? AND ?", Vars: []interface{}{uint64(18), uint64(25)}}, }, }, }, @@ -265,7 +264,7 @@ func TestBetween(t *testing.T) { func TestIsTrue(t *testing.T) { db := openDryRunDB(t) - db = Operators["$istrue"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", schema.Bool) + db = Operators["$istrue"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", DataTypeBool) expected := map[string]clause.Clause{ "WHERE": { @@ -280,13 +279,24 @@ func TestIsTrue(t *testing.T) { assert.Equal(t, expected, db.Statement.Clauses) db = openDryRunDB(t) - db = Operators["$istrue"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", schema.String) // Unsupported type - assert.Empty(t, db.Statement.Clauses) + db = Operators["$istrue"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", DataTypeText) // Unsupported type + + expected = map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + } + assert.Equal(t, expected, db.Statement.Clauses) } func TestIsFalse(t *testing.T) { db := openDryRunDB(t) - db = Operators["$isfalse"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", schema.Bool) + db = Operators["$isfalse"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", DataTypeBool) expected := map[string]clause.Clause{ "WHERE": { @@ -301,6 +311,17 @@ func TestIsFalse(t *testing.T) { assert.Equal(t, expected, db.Statement.Clauses) db = openDryRunDB(t) - db = Operators["$isfalse"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", schema.String) // Unsupported type - assert.Empty(t, db.Statement.Clauses) + db = Operators["$isfalse"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", DataTypeText) // Unsupported type + + expected = map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + } + assert.Equal(t, expected, db.Statement.Clauses) } diff --git a/search.go b/search.go index 3785f34..0cf309d 100644 --- a/search.go +++ b/search.go @@ -57,7 +57,11 @@ func (s *Search) Scope(schema *schema.Schema) func(*gorm.DB) *gorm.DB { fieldExpr = table + "." + tx.Statement.Quote(f.DBName) } - searchQuery = s.Operator.Function(searchQuery, filter, fieldExpr, f.DataType) + dataType := getDataType(f) + if dataType == DataTypeUnsupported { + return tx + } + searchQuery = s.Operator.Function(searchQuery, filter, fieldExpr, dataType) } return tx.Where(searchQuery) diff --git a/search_test.go b/search_test.go index 1ad0f4a..5912ece 100644 --- a/search_test.go +++ b/search_test.go @@ -16,7 +16,7 @@ func TestSearchScope(t *testing.T) { Fields: []string{"name", "email"}, Query: "My Query", Operator: &Operator{ - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return tx.Or(fmt.Sprintf("%s LIKE (?)", column), filter.Args[0]) }, RequiredArguments: 1, @@ -25,9 +25,9 @@ func TestSearchScope(t *testing.T) { schema := &schema.Schema{ FieldsByDBName: map[string]*schema.Field{ - "name": {Name: "Name", DBName: "name"}, - "email": {Name: "Email", DBName: "email"}, - "role": {Name: "Role", DBName: "role"}, + "name": {Name: "Name", DBName: "name", DataType: schema.String}, + "email": {Name: "Email", DBName: "email", DataType: schema.String}, + "role": {Name: "Role", DBName: "role", DataType: schema.String}, }, Table: "test_models", } @@ -80,7 +80,7 @@ func TestSearchScopeEmptyField(t *testing.T) { Fields: []string{}, Query: "My Query", Operator: &Operator{ - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return tx.Or(fmt.Sprintf("%s LIKE (?)", column), filter.Args[0]) }, RequiredArguments: 1, @@ -123,7 +123,7 @@ func TestSeachScopeWithJoin(t *testing.T) { Fields: []string{"name", "Relation.name"}, Query: "My Query", Operator: &Operator{ - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return tx.Or(fmt.Sprintf("%s LIKE (?)", column), filter.Args[0]) }, RequiredArguments: 1, @@ -232,7 +232,7 @@ func TestSeachScopeWithJoinNestedRelation(t *testing.T) { Fields: []string{"name", "Relation.NestedRelation.field"}, Query: "My Query", Operator: &Operator{ - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return tx.Or(fmt.Sprintf("%s LIKE (?)", column), filter.Args[0]) }, RequiredArguments: 1, diff --git a/settings_test.go b/settings_test.go index 138b1c4..d56f601 100644 --- a/settings_test.go +++ b/settings_test.go @@ -107,7 +107,7 @@ func TestScope(t *testing.T) { paginator, db := prepareTestScope(t, &Settings{ FieldsSearch: []string{"email"}, SearchOperator: &Operator{ - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return tx.Or(fmt.Sprintf("%s LIKE (?)", column), filter.Args[0]) }, RequiredArguments: 1, @@ -200,7 +200,7 @@ func TestScopeUnpaginated(t *testing.T) { results, db := prepareTestScopeUnpaginated(t, &Settings{ FieldsSearch: []string{"email"}, SearchOperator: &Operator{ - Function: func(tx *gorm.DB, filter *Filter, column string, dataType schema.DataType) *gorm.DB { + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return tx.Or(fmt.Sprintf("%s LIKE (?)", column), filter.Args[0]) }, RequiredArguments: 1, diff --git a/util.go b/util.go index c8b7340..63d23d1 100644 --- a/util.go +++ b/util.go @@ -1,10 +1,30 @@ package filter import ( + "strconv" + "strings" + "time" + "gorm.io/gorm/schema" "goyave.dev/goyave/v4/util/sliceutil" ) +// DataType is determined by the `filter_type` struct tag (see `DataType` for available options). +// If not given, uses GORM's general DataType. Raw database data types are not supported so it is +// recommended to always specify a `filter_type` in this scenario. +type DataType string + +// Supported DataTypes +const ( + DataTypeText DataType = "text" + DataTypeBool DataType = "bool" + DataTypeInt DataType = "int" + DataTypeUint DataType = "uint" + DataTypeFloat DataType = "float" + DataTypeTime DataType = "time" + DataTypeUnsupported DataType = "unsupported" +) + func cleanColumns(sch *schema.Schema, columns []string, blacklist []string) []*schema.Field { fields := make([]*schema.Field, 0, len(columns)) for _, c := range columns { @@ -47,3 +67,94 @@ func columnsContain(fields []*schema.Field, field *schema.Field) bool { } return false } + +func getDataType(field *schema.Field) DataType { + fromTag := DataType(strings.ToLower(field.Tag.Get("filter_type"))) + switch fromTag { + case DataTypeText, DataTypeBool, DataTypeFloat, DataTypeInt, DataTypeUint, DataTypeTime: + return fromTag + default: + switch field.DataType { + case schema.String: + return DataTypeText + case schema.Bool: + return DataTypeBool + case schema.Float: + return DataTypeFloat + case schema.Int: + return DataTypeInt + case schema.Uint: + return DataTypeUint + case schema.Time: + return DataTypeTime + } + } + return DataTypeUnsupported +} + +// ConvertToSafeType convert the string argument to a safe type that +// matches the column's data type. Returns false if the input could not +// be converted. +func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { // TODO test this + test when datatype doesn't match + switch dataType { + case DataTypeText: + return arg, true + case DataTypeBool: + switch arg { + case "1", "on", "true", "yes": + return true, true + case "0", "off", "false", "no": + return false, true + } + return nil, false + case DataTypeFloat: + i, err := strconv.ParseFloat(arg, 64) + if err != nil { + return nil, false + } + return i, true + case DataTypeInt: + i, err := strconv.ParseInt(arg, 10, 64) // TODO check it works on smallint + if err != nil { + return nil, false + } + return i, true + case DataTypeUint: + i, err := strconv.ParseUint(arg, 10, 64) + if err != nil { + return nil, false + } + return i, true + case DataTypeTime: + if validateTime(arg) { + return arg, true + } + } + return nil, false +} + +func validateTime(timeStr string) bool { + for _, format := range []string{time.RFC3339, "2006-01-02 15:04:05", "2006-01-02"} { + _, err := time.Parse(format, timeStr) + if err == nil { + return true + } + } + + return false +} + +// ConvertArgsToSafeType converts a slice of string arguments to safe type +// that matches the column's data type in the same way as `ConvertToSafeType`. +// If any of the values in the given slice could not be converted, returns false. +func ConvertArgsToSafeType(args []string, dataType DataType) ([]interface{}, bool) { + result := make([]interface{}, 0, len(args)) + for _, arg := range args { + a, ok := ConvertToSafeType(arg, dataType) + if !ok { + return nil, false + } + result = append(result, a) + } + return result, true +} diff --git a/util_test.go b/util_test.go index 1a1a322..34e11bf 100644 --- a/util_test.go +++ b/util_test.go @@ -1,6 +1,7 @@ package filter import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -56,3 +57,64 @@ func TestAddForeignKeys(t *testing.T) { fields = addForeignKeys(schema, fields) assert.ElementsMatch(t, []string{"id", "child_id", "parent_id"}, fields) } + +func TestConvertToSafeType(t *testing.T) { + cases := []struct { + want interface{} + dataType DataType + value string + wantOk bool + }{ + // String + {value: "string", dataType: DataTypeText, want: "string", wantOk: true}, + + // Bool + {value: "1", dataType: DataTypeBool, want: true, wantOk: true}, + {value: "on", dataType: DataTypeBool, want: true, wantOk: true}, + {value: "true", dataType: DataTypeBool, want: true, wantOk: true}, + {value: "yes", dataType: DataTypeBool, want: true, wantOk: true}, + {value: "0", dataType: DataTypeBool, want: false, wantOk: true}, + {value: "off", dataType: DataTypeBool, want: false, wantOk: true}, + {value: "false", dataType: DataTypeBool, want: false, wantOk: true}, + {value: "no", dataType: DataTypeBool, want: false, wantOk: true}, + {value: "not a bool", dataType: DataTypeBool, want: nil, wantOk: false}, + + // Float + {value: "1", dataType: DataTypeFloat, want: 1.0, wantOk: true}, + {value: "1.0", dataType: DataTypeFloat, want: 1.0, wantOk: true}, + {value: "1.23", dataType: DataTypeFloat, want: 1.23, wantOk: true}, + {value: "string", dataType: DataTypeFloat, want: nil, wantOk: false}, + + // Int + {value: "1", dataType: DataTypeInt, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt, want: int64(-2), wantOk: true}, + {value: "1.23", dataType: DataTypeInt, want: nil, wantOk: false}, + {value: "string", dataType: DataTypeInt, want: nil, wantOk: false}, + + // Uint + {value: "1", dataType: DataTypeUint, want: uint64(1), wantOk: true}, + {value: "-2", dataType: DataTypeUint, want: nil, wantOk: false}, + {value: "1.23", dataType: DataTypeUint, want: nil, wantOk: false}, + {value: "string", dataType: DataTypeUint, want: nil, wantOk: false}, + + // Time + {value: "2023-03-23", dataType: DataTypeTime, want: "2023-03-23", wantOk: true}, + {value: "2023-03-23 12:13:24", dataType: DataTypeTime, want: "2023-03-23 12:13:24", wantOk: true}, + {value: "2023-03-23T12:13:24Z", dataType: DataTypeTime, want: "2023-03-23T12:13:24Z", wantOk: true}, + {value: "2023-03-23T12:13:24", dataType: DataTypeTime, want: nil, wantOk: false}, + {value: "not a date", dataType: DataTypeTime, want: nil, wantOk: false}, + {value: "1234", dataType: DataTypeTime, want: nil, wantOk: false}, + + // Unsupported + {value: "1234", dataType: DataTypeUnsupported, want: nil, wantOk: false}, + } + + for _, c := range cases { + c := c + t.Run(fmt.Sprintf("%s_%s", c.value, c.dataType), func(t *testing.T) { + val, ok := ConvertToSafeType(c.value, c.dataType) + assert.Equal(t, c.want, val) + assert.Equal(t, c.wantOk, ok) + }) + } +} From 254b8004fa708540a7b3b3925031009992d94846 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Fri, 31 Mar 2023 16:40:39 +0200 Subject: [PATCH 02/12] Type safety: add support for array types --- README.md | 24 ++++++++++++++---------- filter.go | 2 +- operator.go | 9 +++++++++ search.go | 2 +- util.go | 48 +++++++++++++++++++++++++++++++----------------- util_test.go | 29 +++++++++++++++++++++++++++++ 6 files changed, 85 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 6bd1a70..614813b 100644 --- a/README.md +++ b/README.md @@ -254,22 +254,23 @@ type MyModelWithStatus struct{ ### Filter type -For non-primitive types (such as `*null.Time`), you should always use the `filter_type` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. +For non-primitive types (such as `*null.Time`), you should always use the `filterType` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. Available broad types are: -- `text` -- `bool` -- `int` -- `uint` -- `float` -- `time` +- `text` / `text[]` +- `bool` / `bool[]` +- `int` / `int[]` +- `uint` / `uint[]` +- `float` / `float[]` +- `time` / `time[]` +- `-`: unsupported data type. Fields tagged with `-` will be ignored in filters and search: no condition will be added to the `WHERE` clause. **Example** ```go type MyModel struct{ ID uint // ... - StartDate null.Time `filter_type:"time"` + StartDate null.Time `filterType:"time"` } ``` @@ -302,8 +303,8 @@ import ( filter.Operators["$cont"] = &filter.Operator{ Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB { - if dataType != schema.String { - return tx + if dataType != schema.String || dataType.IsArray() { + return tx.Where("FALSE") } query := column + " LIKE ?" value := "%" + sqlutil.EscapeLike(f.Args[0]) + "%" @@ -314,6 +315,9 @@ filter.Operators["$cont"] = &filter.Operator{ filter.Operators["$eq"] = &filter.Operator{ Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB { + if dataType.IsArray() { + return tx.Where("FALSE") + } arg, ok := filter.ConvertToSafeType(f.Args[0], dataType) if !ok { return tx diff --git a/filter.go b/filter.go index b697e29..34247ed 100644 --- a/filter.go +++ b/filter.go @@ -48,7 +48,7 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) * } dataType := getDataType(field) - if dataType == DataTypeUnsupported { + if dataType == DataTypeUnsupported { // TODO test this return tx } return f.Operator.Function(tx, f, fieldExpr, dataType) diff --git a/operator.go b/operator.go index 8b000cf..c4ea834 100644 --- a/operator.go +++ b/operator.go @@ -109,6 +109,9 @@ var ( }, "$between": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType.IsArray() { + return tx.Where("FALSE") + } args, ok := ConvertArgsToSafeType(filter.Args[:2], dataType) if !ok { return tx.Where("FALSE") @@ -123,6 +126,9 @@ var ( func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType.IsArray() { + return tx.Where("FALSE") + } arg, ok := ConvertToSafeType(filter.Args[0], dataType) if !ok { return tx.Where("FALSE") @@ -134,6 +140,9 @@ func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, func multiComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + if dataType.IsArray() { + return tx.Where("FALSE") + } args, ok := ConvertArgsToSafeType(filter.Args, dataType) if !ok { return tx.Where("FALSE") diff --git a/search.go b/search.go index 0cf309d..a9dc80c 100644 --- a/search.go +++ b/search.go @@ -58,7 +58,7 @@ func (s *Search) Scope(schema *schema.Schema) func(*gorm.DB) *gorm.DB { } dataType := getDataType(f) - if dataType == DataTypeUnsupported { + if dataType == DataTypeUnsupported { // TODO test this return tx } searchQuery = s.Operator.Function(searchQuery, filter, fieldExpr, dataType) diff --git a/util.go b/util.go index 63d23d1..272093b 100644 --- a/util.go +++ b/util.go @@ -9,20 +9,33 @@ import ( "goyave.dev/goyave/v4/util/sliceutil" ) -// DataType is determined by the `filter_type` struct tag (see `DataType` for available options). +// DataType is determined by the `filterType` struct tag (see `DataType` for available options). // If not given, uses GORM's general DataType. Raw database data types are not supported so it is -// recommended to always specify a `filter_type` in this scenario. +// recommended to always specify a `filterType` in this scenario. type DataType string +// IsArray returns true if this data type is an array. +func (d DataType) IsArray() bool { + return strings.HasSuffix(string(d), "[]") +} + // Supported DataTypes const ( - DataTypeText DataType = "text" - DataTypeBool DataType = "bool" - DataTypeInt DataType = "int" - DataTypeUint DataType = "uint" - DataTypeFloat DataType = "float" - DataTypeTime DataType = "time" - DataTypeUnsupported DataType = "unsupported" + DataTypeText DataType = "text" + DataTypeTextArray DataType = "text[]" + DataTypeBool DataType = "bool" + DataTypeBoolArray DataType = "bool[]" + DataTypeInt DataType = "int" + DataTypeIntArray DataType = "int[]" + DataTypeUint DataType = "uint" + DataTypeUintArray DataType = "uint[]" + DataTypeFloat DataType = "float" + DataTypeFloatArray DataType = "float[]" + DataTypeTime DataType = "time" + DataTypeTimeArray DataType = "time[]" + + // DataTypeUnsupported all fields with this tag will be ignored in filters and search. + DataTypeUnsupported DataType = "-" ) func cleanColumns(sch *schema.Schema, columns []string, blacklist []string) []*schema.Field { @@ -69,9 +82,10 @@ func columnsContain(fields []*schema.Field, field *schema.Field) bool { } func getDataType(field *schema.Field) DataType { - fromTag := DataType(strings.ToLower(field.Tag.Get("filter_type"))) + fromTag := DataType(strings.ToLower(field.Tag.Get("filterType"))) switch fromTag { - case DataTypeText, DataTypeBool, DataTypeFloat, DataTypeInt, DataTypeUint, DataTypeTime: + case DataTypeText, DataTypeBool, DataTypeFloat, DataTypeInt, DataTypeUint, DataTypeTime, + DataTypeTextArray, DataTypeBoolArray, DataTypeFloatArray, DataTypeIntArray, DataTypeUintArray, DataTypeTimeArray: return fromTag default: switch field.DataType { @@ -97,9 +111,9 @@ func getDataType(field *schema.Field) DataType { // be converted. func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { // TODO test this + test when datatype doesn't match switch dataType { - case DataTypeText: + case DataTypeText, DataTypeTextArray: return arg, true - case DataTypeBool: + case DataTypeBool, DataTypeBoolArray: switch arg { case "1", "on", "true", "yes": return true, true @@ -107,25 +121,25 @@ func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { // T return false, true } return nil, false - case DataTypeFloat: + case DataTypeFloat, DataTypeFloatArray: i, err := strconv.ParseFloat(arg, 64) if err != nil { return nil, false } return i, true - case DataTypeInt: + case DataTypeInt, DataTypeIntArray: i, err := strconv.ParseInt(arg, 10, 64) // TODO check it works on smallint if err != nil { return nil, false } return i, true - case DataTypeUint: + case DataTypeUint, DataTypeUintArray: i, err := strconv.ParseUint(arg, 10, 64) if err != nil { return nil, false } return i, true - case DataTypeTime: + case DataTypeTime, DataTypeTimeArray: if validateTime(arg) { return arg, true } diff --git a/util_test.go b/util_test.go index 34e11bf..e0e7a84 100644 --- a/util_test.go +++ b/util_test.go @@ -67,6 +67,7 @@ func TestConvertToSafeType(t *testing.T) { }{ // String {value: "string", dataType: DataTypeText, want: "string", wantOk: true}, + {value: "string", dataType: DataTypeTextArray, want: "string", wantOk: true}, // Bool {value: "1", dataType: DataTypeBool, want: true, wantOk: true}, @@ -78,24 +79,45 @@ func TestConvertToSafeType(t *testing.T) { {value: "false", dataType: DataTypeBool, want: false, wantOk: true}, {value: "no", dataType: DataTypeBool, want: false, wantOk: true}, {value: "not a bool", dataType: DataTypeBool, want: nil, wantOk: false}, + {value: "1", dataType: DataTypeBoolArray, want: true, wantOk: true}, + {value: "on", dataType: DataTypeBoolArray, want: true, wantOk: true}, + {value: "true", dataType: DataTypeBoolArray, want: true, wantOk: true}, + {value: "yes", dataType: DataTypeBoolArray, want: true, wantOk: true}, + {value: "0", dataType: DataTypeBoolArray, want: false, wantOk: true}, + {value: "off", dataType: DataTypeBoolArray, want: false, wantOk: true}, + {value: "false", dataType: DataTypeBoolArray, want: false, wantOk: true}, + {value: "no", dataType: DataTypeBoolArray, want: false, wantOk: true}, + {value: "not a bool", dataType: DataTypeBoolArray, want: nil, wantOk: false}, // Float {value: "1", dataType: DataTypeFloat, want: 1.0, wantOk: true}, {value: "1.0", dataType: DataTypeFloat, want: 1.0, wantOk: true}, {value: "1.23", dataType: DataTypeFloat, want: 1.23, wantOk: true}, {value: "string", dataType: DataTypeFloat, want: nil, wantOk: false}, + {value: "1", dataType: DataTypeFloatArray, want: 1.0, wantOk: true}, + {value: "1.0", dataType: DataTypeFloatArray, want: 1.0, wantOk: true}, + {value: "1.23", dataType: DataTypeFloatArray, want: 1.23, wantOk: true}, + {value: "string", dataType: DataTypeFloatArray, want: nil, wantOk: false}, // Int {value: "1", dataType: DataTypeInt, want: int64(1), wantOk: true}, {value: "-2", dataType: DataTypeInt, want: int64(-2), wantOk: true}, {value: "1.23", dataType: DataTypeInt, want: nil, wantOk: false}, {value: "string", dataType: DataTypeInt, want: nil, wantOk: false}, + {value: "1", dataType: DataTypeIntArray, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeIntArray, want: int64(-2), wantOk: true}, + {value: "1.23", dataType: DataTypeIntArray, want: nil, wantOk: false}, + {value: "string", dataType: DataTypeIntArray, want: nil, wantOk: false}, // Uint {value: "1", dataType: DataTypeUint, want: uint64(1), wantOk: true}, {value: "-2", dataType: DataTypeUint, want: nil, wantOk: false}, {value: "1.23", dataType: DataTypeUint, want: nil, wantOk: false}, {value: "string", dataType: DataTypeUint, want: nil, wantOk: false}, + {value: "1", dataType: DataTypeUintArray, want: uint64(1), wantOk: true}, + {value: "-2", dataType: DataTypeUintArray, want: nil, wantOk: false}, + {value: "1.23", dataType: DataTypeUintArray, want: nil, wantOk: false}, + {value: "string", dataType: DataTypeUintArray, want: nil, wantOk: false}, // Time {value: "2023-03-23", dataType: DataTypeTime, want: "2023-03-23", wantOk: true}, @@ -104,9 +126,16 @@ func TestConvertToSafeType(t *testing.T) { {value: "2023-03-23T12:13:24", dataType: DataTypeTime, want: nil, wantOk: false}, {value: "not a date", dataType: DataTypeTime, want: nil, wantOk: false}, {value: "1234", dataType: DataTypeTime, want: nil, wantOk: false}, + {value: "2023-03-23", dataType: DataTypeTimeArray, want: "2023-03-23", wantOk: true}, + {value: "2023-03-23 12:13:24", dataType: DataTypeTimeArray, want: "2023-03-23 12:13:24", wantOk: true}, + {value: "2023-03-23T12:13:24Z", dataType: DataTypeTimeArray, want: "2023-03-23T12:13:24Z", wantOk: true}, + {value: "2023-03-23T12:13:24", dataType: DataTypeTimeArray, want: nil, wantOk: false}, + {value: "not a date", dataType: DataTypeTimeArray, want: nil, wantOk: false}, + {value: "1234", dataType: DataTypeTimeArray, want: nil, wantOk: false}, // Unsupported {value: "1234", dataType: DataTypeUnsupported, want: nil, wantOk: false}, + {value: "1234", dataType: "CHARACTER VARYING(255)[]", want: nil, wantOk: false}, } for _, c := range cases { From 4e66731f0f029128b00bfd2d84f6d84224db0090 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Fri, 31 Mar 2023 17:19:56 +0200 Subject: [PATCH 03/12] Test type-safety WIP --- README.md | 2 + filter.go | 13 ++++-- filter_test.go | 55 +++++++++++++++++++++++ search.go | 8 ++-- search_test.go | 52 ++++++++++++++++++++++ util.go | 7 +-- util_test.go | 117 +++++++++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 243 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 614813b..428dea1 100644 --- a/README.md +++ b/README.md @@ -265,6 +265,8 @@ Available broad types are: - `time` / `time[]` - `-`: unsupported data type. Fields tagged with `-` will be ignored in filters and search: no condition will be added to the `WHERE` clause. +If not provided, the type will be determined from GORM's data type (defined by the `gorm:"type:..."` tag). If GORM's data type is a database type or a type that is not directly supported by this library, the type will fall back to `-` (unsupported). + **Example** ```go type MyModel struct{ diff --git a/filter.go b/filter.go index 34247ed..069f623 100644 --- a/filter.go +++ b/filter.go @@ -24,7 +24,12 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) * return nil, nil } + dataType := getDataType(field) + joinScope := func(tx *gorm.DB) *gorm.DB { + if dataType == DataTypeUnsupported { + return tx + } if joinName != "" { if err := tx.Statement.Parse(tx.Statement.Model); err != nil { tx.AddError(err) @@ -39,6 +44,10 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) * computed := field.StructField.Tag.Get("computed") conditionScope := func(tx *gorm.DB) *gorm.DB { + if dataType == DataTypeUnsupported { + return tx + } + table := tx.Statement.Quote(tableFromJoinName(s.Table, joinName)) var fieldExpr string if computed != "" { @@ -47,10 +56,6 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) * fieldExpr = table + "." + tx.Statement.Quote(field.DBName) } - dataType := getDataType(field) - if dataType == DataTypeUnsupported { // TODO test this - return tx - } return f.Operator.Function(tx, f, fieldExpr, dataType) } diff --git a/filter_test.go b/filter_test.go index 2cf3103..079f553 100644 --- a/filter_test.go +++ b/filter_test.go @@ -112,6 +112,18 @@ type FilterTestModel struct { ID uint } +type FilterTestRelationUnsupported struct { + Name string `filterType:"-"` + ID uint + ParentID uint +} + +type FilterTestModelUnsupported struct { + Relation *FilterTestRelationUnsupported `gorm:"foreignKey:ParentID"` + Name string + ID uint +} + func TestFilterScopeWithJoin(t *testing.T) { db := openDryRunDB(t) filter := &Filter{Field: "Relation.name", Args: []string{"val1"}, Operator: Operators["$eq"]} @@ -608,3 +620,46 @@ func TestFilterScopeComputedRelation(t *testing.T) { } assert.Equal(t, expected, db.Statement.Clauses) } + +func TestFilterScopeWithUnsupportedDataType(t *testing.T) { + db := openDryRunDB(t) + filter := &Filter{Field: "name", Args: []string{"val1"}, Operator: Operators["$eq"]} + schema := &schema.Schema{ + DBNames: []string{"name"}, + FieldsByDBName: map[string]*schema.Field{ + "name": {Name: "Name", DBName: "name", DataType: "CHARACTER VARYING(255)"}, + }, + Table: "test_scope_models", + } + + results := []map[string]interface{}{} + db = db.Scopes(filter.Scope(&Settings{}, schema)).Find(results) + expected := map[string]clause.Clause{} + assert.Equal(t, expected, db.Statement.Clauses) +} + +func TestFilterScopeWithJoinedUnsupportedDataType(t *testing.T) { + db := openDryRunDB(t) + filter := &Filter{Field: "Relation.name", Args: []string{"val1"}, Operator: Operators["$eq"]} + + results := []*FilterTestModelUnsupported{} + schema, err := parseModel(db, &results) + if !assert.Nil(t, err) { + return + } + + db.DryRun = true + db = db.Model(&results).Scopes(filter.Scope(&Settings{}, schema)).Find(&results) + expected := map[string]clause.Clause{ + "FROM": { + Name: "FROM", + Expression: clause.From{}, + }, + "SELECT": { + Name: "SELECT", + Expression: clause.Select{}, + }, + } + assert.Equal(t, expected, db.Statement.Clauses) + assert.Nil(t, db.Error) +} diff --git a/search.go b/search.go index a9dc80c..fb6b8cb 100644 --- a/search.go +++ b/search.go @@ -31,6 +31,10 @@ func (s *Search) Scope(schema *schema.Schema) func(*gorm.DB) *gorm.DB { if f == nil { continue } + dataType := getDataType(f) + if dataType == DataTypeUnsupported { + continue + } if joinName != "" { if err := tx.Statement.Parse(tx.Statement.Model); err != nil { @@ -57,10 +61,6 @@ func (s *Search) Scope(schema *schema.Schema) func(*gorm.DB) *gorm.DB { fieldExpr = table + "." + tx.Statement.Quote(f.DBName) } - dataType := getDataType(f) - if dataType == DataTypeUnsupported { // TODO test this - return tx - } searchQuery = s.Operator.Function(searchQuery, filter, fieldExpr, dataType) } diff --git a/search_test.go b/search_test.go index 5912ece..9b7b1b6 100644 --- a/search_test.go +++ b/search_test.go @@ -438,3 +438,55 @@ func TestSearchScopeComputed(t *testing.T) { } assert.Equal(t, expected, db.Statement.Clauses) } + +func TestSearchScopeWithUnsupportedDataType(t *testing.T) { + db := openDryRunDB(t) + search := &Search{ + Fields: []string{"name", "email"}, + Query: "My Query", + Operator: &Operator{ + Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { + return tx.Or(fmt.Sprintf("%s LIKE (?)", column), filter.Args[0]) + }, + RequiredArguments: 1, + }, + } + + schema := &schema.Schema{ + FieldsByDBName: map[string]*schema.Field{ + "name": {Name: "Name", DBName: "name", DataType: schema.String}, + "email": {Name: "Email", DBName: "email", DataType: "CHARACTER VARYING(255)"}, + "role": {Name: "Role", DBName: "role", DataType: schema.String}, + }, + Table: "test_models", + } + + db = db.Scopes(search.Scope(schema)).Table("table").Find(nil) + expected := map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.AndConditions{ + Exprs: []clause.Expression{ + clause.Expr{ + SQL: "`test_models`.`name` LIKE (?)", + Vars: []interface{}{"My Query"}, + WithoutParentheses: false, + }, + }, + }, + }, + }, + }, + "FROM": { + Name: "FROM", + Expression: clause.From{}, + }, + "SELECT": { + Name: "SELECT", + Expression: clause.Select{}, + }, + } + assert.Equal(t, expected, db.Statement.Clauses) +} diff --git a/util.go b/util.go index 272093b..6dc38b3 100644 --- a/util.go +++ b/util.go @@ -85,9 +85,10 @@ func getDataType(field *schema.Field) DataType { fromTag := DataType(strings.ToLower(field.Tag.Get("filterType"))) switch fromTag { case DataTypeText, DataTypeBool, DataTypeFloat, DataTypeInt, DataTypeUint, DataTypeTime, - DataTypeTextArray, DataTypeBoolArray, DataTypeFloatArray, DataTypeIntArray, DataTypeUintArray, DataTypeTimeArray: + DataTypeTextArray, DataTypeBoolArray, DataTypeFloatArray, DataTypeIntArray, DataTypeUintArray, DataTypeTimeArray, + DataTypeUnsupported: return fromTag - default: + case "": switch field.DataType { case schema.String: return DataTypeText @@ -109,7 +110,7 @@ func getDataType(field *schema.Field) DataType { // ConvertToSafeType convert the string argument to a safe type that // matches the column's data type. Returns false if the input could not // be converted. -func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { // TODO test this + test when datatype doesn't match +func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { switch dataType { case DataTypeText, DataTypeTextArray: return arg, true diff --git a/util_test.go b/util_test.go index e0e7a84..b003188 100644 --- a/util_test.go +++ b/util_test.go @@ -147,3 +147,120 @@ func TestConvertToSafeType(t *testing.T) { }) } } + +func TestConvertArgsToSafeType(t *testing.T) { + + // No need for exhaustive testing here since it's already done by TestConvertToSafeType + cases := []struct { + want interface{} + dataType DataType + value []string + wantOk bool + }{ + {value: []string{"a", "b"}, dataType: DataTypeText, want: []interface{}{"a", "b"}, wantOk: true}, + {value: []string{"3", "4"}, dataType: DataTypeInt, want: []interface{}{int64(3), int64(4)}, wantOk: true}, + {value: []string{"a", "2"}, dataType: DataTypeInt, want: []interface{}(nil), wantOk: false}, + } + + for _, c := range cases { + c := c + t.Run(fmt.Sprintf("%s_%s", c.value, c.dataType), func(t *testing.T) { + val, ok := ConvertArgsToSafeType(c.value, c.dataType) + assert.Equal(t, c.want, val) + assert.Equal(t, c.wantOk, ok) + }) + } +} + +func TestGetDataType(t *testing.T) { + cases := []struct { + desc string + model interface{} + want DataType + }{ + {desc: "no specified type", model: struct{ Field string }{}, want: DataTypeText}, + {desc: "gorm type string", model: struct { + Field string `gorm:"type:string"` + }{}, want: DataTypeText}, + {desc: "gorm type bool", model: struct { + Field string `gorm:"type:bool"` + }{}, want: DataTypeBool}, + {desc: "gorm type int", model: struct { + Field string `gorm:"type:int"` + }{}, want: DataTypeInt}, + {desc: "gorm type uint", model: struct { + Field string `gorm:"type:uint"` + }{}, want: DataTypeUint}, + {desc: "gorm type float", model: struct { + Field string `gorm:"type:float"` + }{}, want: DataTypeFloat}, + {desc: "gorm type time", model: struct { + Field string `gorm:"type:time"` + }{}, want: DataTypeTime}, + {desc: "gorm type bytes", model: struct { + Field string `gorm:"type:bytes"` + }{}, want: DataTypeUnsupported}, + {desc: "gorm type custom", model: struct { + Field string `gorm:"type:CHARACTER VARYING(255)"` + }{}, want: DataTypeUnsupported}, + + {desc: "filter type unsupported", model: struct { + Field string `filterType:"-"` + }{}, want: DataTypeUnsupported}, + {desc: "filter type invalid", model: struct { + Field string `filterType:"invalid"` + }{}, want: DataTypeUnsupported}, + {desc: "filter type text", model: struct { + Field string `filterType:"text"` + }{}, want: DataTypeText}, + {desc: "filter type text array", model: struct { + Field string `filterType:"text[]"` + }{}, want: DataTypeTextArray}, + {desc: "filter type bool", model: struct { + Field string `filterType:"bool"` + }{}, want: DataTypeBool}, + {desc: "filter type bool array", model: struct { + Field string `filterType:"bool[]"` + }{}, want: DataTypeBoolArray}, + {desc: "filter type float", model: struct { + Field string `filterType:"float"` + }{}, want: DataTypeFloat}, + {desc: "filter type float array", model: struct { + Field string `filterType:"float[]"` + }{}, want: DataTypeFloatArray}, + {desc: "filter type int", model: struct { + Field string `filterType:"int"` + }{}, want: DataTypeInt}, + {desc: "filter type int array", model: struct { + Field string `filterType:"int[]"` + }{}, want: DataTypeIntArray}, + {desc: "filter type uint", model: struct { + Field string `filterType:"uint"` + }{}, want: DataTypeUint}, + {desc: "filter type uint array", model: struct { + Field string `filterType:"uint[]"` + }{}, want: DataTypeUintArray}, + {desc: "filter type time", model: struct { + Field string `filterType:"time"` + }{}, want: DataTypeTime}, + {desc: "filter type time array", model: struct { + Field string `filterType:"time[]"` + }{}, want: DataTypeTimeArray}, + + {desc: "filter type has priority over gorm type", model: struct { + Field string `gorm:"type:CHARACTER VARYING(255)" filterType:"text"` + }{}, want: DataTypeText}, + } + + for _, c := range cases { + c := c + t.Run(c.desc, func(t *testing.T) { + model, err := parseModel(openDryRunDB(t), c.model) + if !assert.NoError(t, err) { + return + } + + getDataType(model.Fields[0]) + }) + } +} From 3c559eeb33fb8743625901670cd055be30f80fa4 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Tue, 4 Apr 2023 17:19:14 +0200 Subject: [PATCH 04/12] Rewrite operator tests --- operator_test.go | 1081 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 887 insertions(+), 194 deletions(-) diff --git a/operator_test.go b/operator_test.go index 8f2d709..363f962 100644 --- a/operator_test.go +++ b/operator_test.go @@ -7,321 +7,1014 @@ import ( "gorm.io/gorm/clause" ) -func TestEquals(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$eq"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) +type operatorTestCase struct { + filter *Filter + want map[string]clause.Clause + desc string + op string + column string + dataType DataType +} - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` = ?", Vars: []interface{}{"test"}}, +func TestEquals(t *testing.T) { + cases := []operatorTestCase{ + { + desc: "ok", + op: "$eq", + filter: &Filter{Field: "name", Args: []string{"test"}}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` = ?", Vars: []interface{}{"test"}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$eq", + filter: &Filter{Field: "name", Args: []string{"test"}}, + column: "`test_models`.`name`", + dataType: DataTypeTextArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_convert_to_int", + op: "$eq", + filter: &Filter{Field: "age", Args: []string{"test"}}, + column: "`test_models`.`age`", + dataType: DataTypeFloat, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestNotEquals(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$ne"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` <> ?", Vars: []interface{}{"test"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$ne", + filter: &Filter{Field: "name", Args: []string{"test"}}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` <> ?", Vars: []interface{}{"test"}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$ne", + filter: &Filter{Field: "name", Args: []string{"test"}}, + column: "`test_models`.`name`", + dataType: DataTypeTextArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_convert_to_int", + op: "$ne", + filter: &Filter{Field: "age", Args: []string{"test"}}, + column: "`test_models`.`age`", + dataType: DataTypeFloat, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestGreaterThan(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$gt"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`age` > ?", Vars: []interface{}{"18"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$gt", + filter: &Filter{Field: "age", Args: []string{"18"}}, + column: "`test_models`.`age`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`age` > ?", Vars: []interface{}{int64(18)}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$gt", + filter: &Filter{Field: "age", Args: []string{"18"}}, + column: "`test_models`.`age`", + dataType: DataTypeIntArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_convert_to_int", + op: "$gt", + filter: &Filter{Field: "age", Args: []string{"test"}}, + column: "`test_models`.`age`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestLowerThan(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$lt"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`age` < ?", Vars: []interface{}{"18"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$lt", + filter: &Filter{Field: "age", Args: []string{"18"}}, + column: "`test_models`.`age`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`age` < ?", Vars: []interface{}{int64(18)}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$lt", + filter: &Filter{Field: "age", Args: []string{"18"}}, + column: "`test_models`.`age`", + dataType: DataTypeIntArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_convert_to_int", + op: "$lt", + filter: &Filter{Field: "age", Args: []string{"test"}}, + column: "`test_models`.`age`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestGreaterThanEqual(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$gte"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`age` >= ?", Vars: []interface{}{"18"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$gte", + filter: &Filter{Field: "age", Args: []string{"18"}}, + column: "`test_models`.`age`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`age` >= ?", Vars: []interface{}{int64(18)}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$gte", + filter: &Filter{Field: "age", Args: []string{"18"}}, + column: "`test_models`.`age`", + dataType: DataTypeIntArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_convert_to_int", + op: "$gte", + filter: &Filter{Field: "age", Args: []string{"test"}}, + column: "`test_models`.`age`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestLowerThanEqual(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$lte"].Function(db, &Filter{Field: "age", Args: []string{"18"}}, "`test_models`.`age`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`age` <= ?", Vars: []interface{}{"18"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$lte", + filter: &Filter{Field: "age", Args: []string{"18"}}, + column: "`test_models`.`age`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`age` <= ?", Vars: []interface{}{int64(18)}}, + }, + }, }, }, }, + { + desc: "cannot_compare_array", + op: "$lte", + filter: &Filter{Field: "age", Args: []string{"18"}}, + column: "`test_models`.`age`", + dataType: DataTypeIntArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_convert_to_int", + op: "$lte", + filter: &Filter{Field: "age", Args: []string{"test"}}, + column: "`test_models`.`age`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) } - assert.Equal(t, expected, db.Statement.Clauses) } func TestStarts(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$starts"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` LIKE ?", Vars: []interface{}{"test%"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$starts", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` LIKE ?", Vars: []interface{}{"te\\%\\_st%"}}, + }, + }, }, }, }, + { + desc: "cannot_compare_array", + op: "$starts", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeTextArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_use_with_int", + op: "$starts", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) } - assert.Equal(t, expected, db.Statement.Clauses) } func TestEnds(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$ends"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` LIKE ?", Vars: []interface{}{"%test"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$ends", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` LIKE ?", Vars: []interface{}{"%te\\%\\_st"}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$ends", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeTextArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_use_with_int", + op: "$ends", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestContains(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$cont"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` LIKE ?", Vars: []interface{}{"%test%"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$cont", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` LIKE ?", Vars: []interface{}{"%te\\%\\_st%"}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$cont", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeTextArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_use_with_int", + op: "$cont", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestNotContains(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$excl"].Function(db, &Filter{Field: "name", Args: []string{"test"}}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` NOT LIKE ?", Vars: []interface{}{"%test%"}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$excl", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` NOT LIKE ?", Vars: []interface{}{"%te\\%\\_st%"}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$excl", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeTextArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_use_with_int", + op: "$excl", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestIn(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$in"].Function(db, &Filter{Field: "name", Args: []string{"val1", "val2"}}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` IN ?", Vars: []interface{}{[]interface{}{"val1", "val2"}}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$in", + filter: &Filter{Field: "name", Args: []string{"val1", "val2"}}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` IN ?", Vars: []interface{}{[]interface{}{"val1", "val2"}}}, + }, + }, }, }, }, + { + desc: "cannot_compare_array", + op: "$in", + filter: &Filter{Field: "name", Args: []string{"val1", "val2"}}, + column: "`test_models`.`name`", + dataType: DataTypeTextArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + { + desc: "cannot_convert_to_int", + op: "$in", + filter: &Filter{Field: "name", Args: []string{"18", "val2"}}, + column: "`test_models`.`name`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) } - assert.Equal(t, expected, db.Statement.Clauses) } func TestNotIn(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$notin"].Function(db, &Filter{Field: "name", Args: []string{"val1", "val2"}}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` NOT IN ?", Vars: []interface{}{[]interface{}{"val1", "val2"}}}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$notin", + filter: &Filter{Field: "name", Args: []string{"val1", "val2"}}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` NOT IN ?", Vars: []interface{}{[]interface{}{"val1", "val2"}}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$notin", + filter: &Filter{Field: "name", Args: []string{"val1", "val2"}}, + column: "`test_models`.`name`", + dataType: DataTypeTextArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, + { + desc: "cannot_convert_to_int", + op: "$notin", + filter: &Filter{Field: "name", Args: []string{"18", "val2"}}, + column: "`test_models`.`name`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) } - assert.Equal(t, expected, db.Statement.Clauses) } func TestIsNull(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$isnull"].Function(db, &Filter{Field: "name"}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` IS NULL"}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$isnull", + filter: &Filter{Field: "name"}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` IS NULL"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestNotNull(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$notnull"].Function(db, &Filter{Field: "name"}, "`test_models`.`name`", DataTypeText) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`name` IS NOT NULL"}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$notnull", + filter: &Filter{Field: "name"}, + column: "`test_models`.`name`", + dataType: DataTypeText, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`name` IS NOT NULL"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestBetween(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$between"].Function(db, &Filter{Field: "age", Args: []string{"18", "25"}}, "`test_models`.`age`", DataTypeUint) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`age` BETWEEN ? AND ?", Vars: []interface{}{uint64(18), uint64(25)}}, + cases := []operatorTestCase{ + { + desc: "ok_int", + op: "$between", + filter: &Filter{Field: "age", Args: []string{"18", "25"}}, + column: "`test_models`.`age`", + dataType: DataTypeUint, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`age` BETWEEN ? AND ?", Vars: []interface{}{uint64(18), uint64(25)}}, + }, + }, + }, + }, + }, + { + desc: "ok_time", + op: "$between", + filter: &Filter{Field: "birthday", Args: []string{"2023-04-04", "2023-05-05 12:00:00"}}, + column: "`test_models`.`birthday`", + dataType: DataTypeTime, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`birthday` BETWEEN ? AND ?", Vars: []interface{}{"2023-04-04", "2023-05-05 12:00:00"}}, + }, + }, + }, + }, + }, + { + desc: "cannot_compare_array", + op: "$between", + filter: &Filter{Field: "birthday", Args: []string{"2023-04-04", "2023-05-05 12:00:00"}}, + column: "`test_models`.`birthday`", + dataType: DataTypeTimeArray, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, + { + desc: "cannot_convert_to_int", + op: "$between", + filter: &Filter{Field: "age", Args: []string{"18", "val2"}}, + column: "`test_models`.`age`", + dataType: DataTypeUint, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, + }, + }, + }, + } + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) } - assert.Equal(t, expected, db.Statement.Clauses) } func TestIsTrue(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$istrue"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", DataTypeBool) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`is_active` IS TRUE"}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$istrue", + filter: &Filter{Field: "is_active"}, + column: "`test_models`.`is_active`", + dataType: DataTypeBool, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`is_active` IS TRUE"}, + }, + }, }, }, }, - } - assert.Equal(t, expected, db.Statement.Clauses) - - db = openDryRunDB(t) - db = Operators["$istrue"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", DataTypeText) // Unsupported type - - expected = map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "FALSE"}, + { + desc: "cannot_use_with_int", + op: "$istrue", + filter: &Filter{Field: "is_active"}, + column: "`test_models`.`is_active`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } func TestIsFalse(t *testing.T) { - db := openDryRunDB(t) - db = Operators["$isfalse"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", DataTypeBool) - - expected := map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "`test_models`.`is_active` IS FALSE"}, + cases := []operatorTestCase{ + { + desc: "ok", + op: "$isfalse", + filter: &Filter{Field: "is_active"}, + column: "`test_models`.`is_active`", + dataType: DataTypeBool, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "`test_models`.`is_active` IS FALSE"}, + }, + }, }, }, }, - } - assert.Equal(t, expected, db.Statement.Clauses) - - db = openDryRunDB(t) - db = Operators["$isfalse"].Function(db, &Filter{Field: "isActive"}, "`test_models`.`is_active`", DataTypeText) // Unsupported type - - expected = map[string]clause.Clause{ - "WHERE": { - Name: "WHERE", - Expression: clause.Where{ - Exprs: []clause.Expression{ - clause.Expr{SQL: "FALSE"}, + { + desc: "cannot_use_with_int", + op: "$isfalse", + filter: &Filter{Field: "is_active"}, + column: "`test_models`.`is_active`", + dataType: DataTypeInt, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "FALSE"}, + }, + }, }, }, }, } - assert.Equal(t, expected, db.Statement.Clauses) + + for _, c := range cases { + t.Run(c.desc, func(t *testing.T) { + db := openDryRunDB(t) + db = Operators[c.op].Function(db, c.filter, c.column, c.dataType) + assert.Equal(t, c.want, db.Statement.Clauses) + }) + } } From 2f49682cc67027fb5de42e22e60a5f1393cd6c0e Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Tue, 4 Apr 2023 17:27:57 +0200 Subject: [PATCH 05/12] Update custom operator example --- README.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 428dea1..357da62 100644 --- a/README.md +++ b/README.md @@ -267,6 +267,8 @@ Available broad types are: If not provided, the type will be determined from GORM's data type (defined by the `gorm:"type:..."` tag). If GORM's data type is a database type or a type that is not directly supported by this library, the type will fall back to `-` (unsupported). +If the user input cannot be used with the requested column, the built-in operators will generate a `FALSE` condition. + **Example** ```go type MyModel struct{ @@ -305,7 +307,7 @@ import ( filter.Operators["$cont"] = &filter.Operator{ Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB { - if dataType != schema.String || dataType.IsArray() { + if dataType != filter.DataTypeString { return tx.Where("FALSE") } query := column + " LIKE ?" @@ -322,7 +324,7 @@ filter.Operators["$eq"] = &filter.Operator{ } arg, ok := filter.ConvertToSafeType(f.Args[0], dataType) if !ok { - return tx + return tx.Where("FALSE") } query := fmt.Sprintf("%s = ?", column, op) return f.Where(tx, query, arg) From 333a8929c468bd98b76661f4a983f945e9f97a3b Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Wed, 5 Apr 2023 11:03:09 +0200 Subject: [PATCH 06/12] Add docs on computed json fields --- README.md | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 357da62..8f9bac1 100644 --- a/README.md +++ b/README.md @@ -208,11 +208,11 @@ Internally, `goyave.dev/filter` uses [Goyave's `Paginator`](https://goyave.dev/g Sometimes you need to work with a "virtual" column that is not stored in your database, but is computed using an SQL expression. A dynamic status depending on a date for example. In order to support the features of this library properly, you will have to add the expression to your model using the `computed` struct tag: ```go -type MyModel struct{ +type MyModel struct { ID uint // ... StartDate time.Time - Status string `gorm:"->;-:migration" computed:"CASE WHEN ~~~ct~~~.start_date < NOW() THEN 'pending' ELSE 'started' END"` + Status string `gorm:"->;-:migration" computed:"CASE WHEN ~~~ct~~~.start_date < NOW() THEN 'pending' ELSE 'started' END"` } ``` @@ -232,6 +232,20 @@ type MyModelWithStatus struct{ } ``` +When using JSON columns, you can support filters on nested fields inside that JSON column using a computed column: + +```go +// This example is compatible with PostgreSQL. +// JSON processing may be different if you are using another database engine. +type MyModel struct { + ID uint + JSONColumn datatypes.JSON + SomeJSONField null.Int `gorm:"->;-:migration" computed:"(~~~ct~~~.json_column->>'fieldName')::int"` +} +``` + +It is important to make sure your JSON expression returns a value that has a type that matches the struct field to avoid DB errors. Database engines usually only return text types from JSON. If your field is a number, you'll have to cast it or you will get database errors when filtering on this field. + ## Security - Inputs are escaped to prevent SQL injections. From cb3c4a171b6a7358ae21e8e25b4e7d63f250c520 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Wed, 5 Apr 2023 11:33:14 +0200 Subject: [PATCH 07/12] Add RFC3339Nano to supported formats for time --- util.go | 2 +- util_test.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/util.go b/util.go index 6dc38b3..a08695e 100644 --- a/util.go +++ b/util.go @@ -149,7 +149,7 @@ func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { } func validateTime(timeStr string) bool { - for _, format := range []string{time.RFC3339, "2006-01-02 15:04:05", "2006-01-02"} { + for _, format := range []string{time.RFC3339, time.RFC3339Nano, "2006-01-02 15:04:05", "2006-01-02"} { _, err := time.Parse(format, timeStr) if err == nil { return true diff --git a/util_test.go b/util_test.go index b003188..758602a 100644 --- a/util_test.go +++ b/util_test.go @@ -123,12 +123,14 @@ func TestConvertToSafeType(t *testing.T) { {value: "2023-03-23", dataType: DataTypeTime, want: "2023-03-23", wantOk: true}, {value: "2023-03-23 12:13:24", dataType: DataTypeTime, want: "2023-03-23 12:13:24", wantOk: true}, {value: "2023-03-23T12:13:24Z", dataType: DataTypeTime, want: "2023-03-23T12:13:24Z", wantOk: true}, + {value: "2022-11-02T09:12:03.081967+01:00", dataType: DataTypeTime, want: "2022-11-02T09:12:03.081967+01:00", wantOk: true}, {value: "2023-03-23T12:13:24", dataType: DataTypeTime, want: nil, wantOk: false}, {value: "not a date", dataType: DataTypeTime, want: nil, wantOk: false}, {value: "1234", dataType: DataTypeTime, want: nil, wantOk: false}, {value: "2023-03-23", dataType: DataTypeTimeArray, want: "2023-03-23", wantOk: true}, {value: "2023-03-23 12:13:24", dataType: DataTypeTimeArray, want: "2023-03-23 12:13:24", wantOk: true}, {value: "2023-03-23T12:13:24Z", dataType: DataTypeTimeArray, want: "2023-03-23T12:13:24Z", wantOk: true}, + {value: "2022-11-02T09:12:03.081967+01:00", dataType: DataTypeTimeArray, want: "2022-11-02T09:12:03.081967+01:00", wantOk: true}, {value: "2023-03-23T12:13:24", dataType: DataTypeTimeArray, want: nil, wantOk: false}, {value: "not a date", dataType: DataTypeTimeArray, want: nil, wantOk: false}, {value: "1234", dataType: DataTypeTimeArray, want: nil, wantOk: false}, From 896212f0c1608ef813f5fc16dd5a1931aea09029 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Wed, 5 Apr 2023 14:46:21 +0200 Subject: [PATCH 08/12] Add documentation on how to support array operators --- README.md | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8f9bac1..1efe929 100644 --- a/README.md +++ b/README.md @@ -268,7 +268,7 @@ It is important to make sure your JSON expression returns a value that has a typ ### Filter type -For non-primitive types (such as `*null.Time`), you should always use the `filterType` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. +For non-native types (such as `*null.Time`), you should always use the `filterType` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. Available broad types are: - `text` / `text[]` @@ -347,6 +347,66 @@ filter.Operators["$eq"] = &filter.Operator{ } ``` +#### Array operators + +Some database engines such as PostgreSQL provide operators for array operations (`@>`, `&&`, ...). You may encounter issue implementing these operators in your project because of GORM converting slices into records (`("a", "b")` instead of `{"a", "b"}`). + +To fix this issue, you will have to implement your own variant of `ConvertArgsToSafeType` so it returns a **pointer** to a slice with a concrete type instead of `[]interface{}`. By sending a pointer to GORM, it won't try to render the slice itself and pass it directly to the underlying driver, which usually knows how to handle slices for the native types. + +**Example** (using generics with go 1.18+): +```go +type argType interface { + string | int64 | uint64 | float64 | bool +} + +func init() { + filter.Operators["$arrayin"] = &filter.Operator{ + Function: func (tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB { + if !dataType.IsArray() { + return tx.Where("FALSE") + } + + query := fmt.Sprintf("%s @> ?", column) + switch dataType { + case filter.DataTypeTextArray, filter.DataTypeTimeArray: + return bindArrayArg[string](tx, query, f, dataType) + case filter.DataTypeFloatArray: + return bindArrayArg[float64](tx, query, f, dataType) + case filter.DataTypeUintArray: + return bindArrayArg[uint64](tx, query, f, dataType) + case filter.DataTypeIntArray: + return bindArrayArg[int64](tx, query, f, dataType) + } + + // If you need to handle DataTypeBoolArray, use pgtype.BoolArray + return tx.Where("FALSE") + }, + RequiredArguments: 1, + } +} + +func bindArrayArg[T argType](tx *gorm.DB, query string, f *filter.Filter, dataType filter.DataType) *gorm.DB { + args, ok := convertArgsToSafeTypeArray[T](f.Args, dataType) + if !ok { + return tx.Where("FALSE") + } + return f.Where(tx, query, args) +} + +func convertArgsToSafeTypeArray[T argType](args []string, dataType filter.DataType) (*[]T, bool) { + result := make([]T, 0, len(args)) + for _, arg := range args { + a, ok := filter.ConvertToSafeType(arg, dataType) + if !ok { + return nil, false + } + result = append(result, a.(T)) + } + + return &result, true +} +``` + ### Manual joins Manual joins are supported and won't clash with joins that are automatically generated by the library. That means that if needed, you can write something like described in the following piece of code. From 307a3d7eeb29062fd5bf2ea44048b05d339a19bb Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Wed, 5 Apr 2023 15:52:54 +0200 Subject: [PATCH 09/12] Consider numbers bit size for type-safety --- README.md | 18 ++-- operator_test.go | 48 +++++------ util.go | 150 ++++++++++++++++++++++++++-------- util_test.go | 209 +++++++++++++++++++++++++++++++++++++---------- 4 files changed, 311 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index 1efe929..9d418ed 100644 --- a/README.md +++ b/README.md @@ -268,20 +268,20 @@ It is important to make sure your JSON expression returns a value that has a typ ### Filter type -For non-native types (such as `*null.Time`), you should always use the `filterType` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. +For non-native types that don't implement the `driver.Valuer` interface, you should always use the `filterType` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. It is also recommended to always add this tag when working with arrays. Available broad types are: - `text` / `text[]` - `bool` / `bool[]` -- `int` / `int[]` -- `uint` / `uint[]` -- `float` / `float[]` +- `int8` / `int8[]`, `int16` / `int16[]`, `int32` / `int32[]`, `int64` / `int64[]` +- `uint` / `uint[]`, `uint16` / `uint16[]`, `uint32` / `uint32[]`, `uint64` / `uint64[]` +- `float32` / `float32[]`, `float64` / `float64[]` - `time` / `time[]` - `-`: unsupported data type. Fields tagged with `-` will be ignored in filters and search: no condition will be added to the `WHERE` clause. -If not provided, the type will be determined from GORM's data type (defined by the `gorm:"type:..."` tag). If GORM's data type is a database type or a type that is not directly supported by this library, the type will fall back to `-` (unsupported). +If not provided, the type will be determined from GORM's data type. If GORM's data type is a custom type that is not directly supported by this library, the type will fall back to `-` (unsupported) and the field will be ignored in the filters. -If the user input cannot be used with the requested column, the built-in operators will generate a `FALSE` condition. +If the type is supported but the user input cannot be used with the requested column, the built-in operators will generate a `FALSE` condition. **Example** ```go @@ -370,11 +370,11 @@ func init() { switch dataType { case filter.DataTypeTextArray, filter.DataTypeTimeArray: return bindArrayArg[string](tx, query, f, dataType) - case filter.DataTypeFloatArray: + case filter.DataTypeFloat32Array, filter.DataTypeFloat64Array: return bindArrayArg[float64](tx, query, f, dataType) - case filter.DataTypeUintArray: + case filter.DataTypeUint8Array, filter.DataTypeUint16Array, filter.DataTypeUint32Array, filter.DataTypeUint64Array: return bindArrayArg[uint64](tx, query, f, dataType) - case filter.DataTypeIntArray: + case filter.DataTypeInt8Array, filter.DataTypeInt16Array, filter.DataTypeInt32Array, filter.DataTypeInt64Array: return bindArrayArg[int64](tx, query, f, dataType) } diff --git a/operator_test.go b/operator_test.go index 363f962..0d9ecca 100644 --- a/operator_test.go +++ b/operator_test.go @@ -57,7 +57,7 @@ func TestEquals(t *testing.T) { op: "$eq", filter: &Filter{Field: "age", Args: []string{"test"}}, column: "`test_models`.`age`", - dataType: DataTypeFloat, + dataType: DataTypeUint64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -121,7 +121,7 @@ func TestNotEquals(t *testing.T) { op: "$ne", filter: &Filter{Field: "age", Args: []string{"test"}}, column: "`test_models`.`age`", - dataType: DataTypeFloat, + dataType: DataTypeUint64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -151,7 +151,7 @@ func TestGreaterThan(t *testing.T) { op: "$gt", filter: &Filter{Field: "age", Args: []string{"18"}}, column: "`test_models`.`age`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -168,7 +168,7 @@ func TestGreaterThan(t *testing.T) { op: "$gt", filter: &Filter{Field: "age", Args: []string{"18"}}, column: "`test_models`.`age`", - dataType: DataTypeIntArray, + dataType: DataTypeInt64Array, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -185,7 +185,7 @@ func TestGreaterThan(t *testing.T) { op: "$gt", filter: &Filter{Field: "age", Args: []string{"test"}}, column: "`test_models`.`age`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -215,7 +215,7 @@ func TestLowerThan(t *testing.T) { op: "$lt", filter: &Filter{Field: "age", Args: []string{"18"}}, column: "`test_models`.`age`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -232,7 +232,7 @@ func TestLowerThan(t *testing.T) { op: "$lt", filter: &Filter{Field: "age", Args: []string{"18"}}, column: "`test_models`.`age`", - dataType: DataTypeIntArray, + dataType: DataTypeInt64Array, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -249,7 +249,7 @@ func TestLowerThan(t *testing.T) { op: "$lt", filter: &Filter{Field: "age", Args: []string{"test"}}, column: "`test_models`.`age`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -279,7 +279,7 @@ func TestGreaterThanEqual(t *testing.T) { op: "$gte", filter: &Filter{Field: "age", Args: []string{"18"}}, column: "`test_models`.`age`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -296,7 +296,7 @@ func TestGreaterThanEqual(t *testing.T) { op: "$gte", filter: &Filter{Field: "age", Args: []string{"18"}}, column: "`test_models`.`age`", - dataType: DataTypeIntArray, + dataType: DataTypeInt64Array, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -313,7 +313,7 @@ func TestGreaterThanEqual(t *testing.T) { op: "$gte", filter: &Filter{Field: "age", Args: []string{"test"}}, column: "`test_models`.`age`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -343,7 +343,7 @@ func TestLowerThanEqual(t *testing.T) { op: "$lte", filter: &Filter{Field: "age", Args: []string{"18"}}, column: "`test_models`.`age`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -360,7 +360,7 @@ func TestLowerThanEqual(t *testing.T) { op: "$lte", filter: &Filter{Field: "age", Args: []string{"18"}}, column: "`test_models`.`age`", - dataType: DataTypeIntArray, + dataType: DataTypeInt64Array, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -377,7 +377,7 @@ func TestLowerThanEqual(t *testing.T) { op: "$lte", filter: &Filter{Field: "age", Args: []string{"test"}}, column: "`test_models`.`age`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -441,7 +441,7 @@ func TestStarts(t *testing.T) { op: "$starts", filter: &Filter{Field: "name", Args: []string{"te%_st"}}, column: "`test_models`.`name`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -505,7 +505,7 @@ func TestEnds(t *testing.T) { op: "$ends", filter: &Filter{Field: "name", Args: []string{"te%_st"}}, column: "`test_models`.`name`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -569,7 +569,7 @@ func TestContains(t *testing.T) { op: "$cont", filter: &Filter{Field: "name", Args: []string{"te%_st"}}, column: "`test_models`.`name`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -633,7 +633,7 @@ func TestNotContains(t *testing.T) { op: "$excl", filter: &Filter{Field: "name", Args: []string{"te%_st"}}, column: "`test_models`.`name`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -697,7 +697,7 @@ func TestIn(t *testing.T) { op: "$in", filter: &Filter{Field: "name", Args: []string{"18", "val2"}}, column: "`test_models`.`name`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -761,7 +761,7 @@ func TestNotIn(t *testing.T) { op: "$notin", filter: &Filter{Field: "name", Args: []string{"18", "val2"}}, column: "`test_models`.`name`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -851,7 +851,7 @@ func TestBetween(t *testing.T) { op: "$between", filter: &Filter{Field: "age", Args: []string{"18", "25"}}, column: "`test_models`.`age`", - dataType: DataTypeUint, + dataType: DataTypeUint64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -902,7 +902,7 @@ func TestBetween(t *testing.T) { op: "$between", filter: &Filter{Field: "age", Args: []string{"18", "val2"}}, column: "`test_models`.`age`", - dataType: DataTypeUint, + dataType: DataTypeUint64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -949,7 +949,7 @@ func TestIsTrue(t *testing.T) { op: "$istrue", filter: &Filter{Field: "is_active"}, column: "`test_models`.`is_active`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", @@ -996,7 +996,7 @@ func TestIsFalse(t *testing.T) { op: "$isfalse", filter: &Filter{Field: "is_active"}, column: "`test_models`.`is_active`", - dataType: DataTypeInt, + dataType: DataTypeInt64, want: map[string]clause.Clause{ "WHERE": { Name: "WHERE", diff --git a/util.go b/util.go index a08695e..74cbcbc 100644 --- a/util.go +++ b/util.go @@ -21,18 +21,37 @@ func (d DataType) IsArray() bool { // Supported DataTypes const ( - DataTypeText DataType = "text" - DataTypeTextArray DataType = "text[]" - DataTypeBool DataType = "bool" - DataTypeBoolArray DataType = "bool[]" - DataTypeInt DataType = "int" - DataTypeIntArray DataType = "int[]" - DataTypeUint DataType = "uint" - DataTypeUintArray DataType = "uint[]" - DataTypeFloat DataType = "float" - DataTypeFloatArray DataType = "float[]" - DataTypeTime DataType = "time" - DataTypeTimeArray DataType = "time[]" + DataTypeText DataType = "text" + DataTypeTextArray DataType = "text[]" + + DataTypeBool DataType = "bool" + DataTypeBoolArray DataType = "bool[]" + + DataTypeInt8 DataType = "int8" + DataTypeInt8Array DataType = "int8[]" + DataTypeInt16 DataType = "int16" + DataTypeInt16Array DataType = "int16[]" + DataTypeInt32 DataType = "int32" + DataTypeInt32Array DataType = "int32[]" + DataTypeInt64 DataType = "int64" + DataTypeInt64Array DataType = "int64[]" + + DataTypeUint8 DataType = "uint8" + DataTypeUint8Array DataType = "uint8[]" + DataTypeUint16 DataType = "uint16" + DataTypeUint16Array DataType = "uint16[]" + DataTypeUint32 DataType = "uint32" + DataTypeUint32Array DataType = "uint32[]" + DataTypeUint64 DataType = "uint64" + DataTypeUint64Array DataType = "uint64[]" + + DataTypeFloat32 DataType = "float32" + DataTypeFloat32Array DataType = "float32[]" + DataTypeFloat64 DataType = "float64" + DataTypeFloat64Array DataType = "float64[]" + + DataTypeTime DataType = "time" + DataTypeTimeArray DataType = "time[]" // DataTypeUnsupported all fields with this tag will be ignored in filters and search. DataTypeUnsupported DataType = "-" @@ -84,22 +103,55 @@ func columnsContain(fields []*schema.Field, field *schema.Field) bool { func getDataType(field *schema.Field) DataType { fromTag := DataType(strings.ToLower(field.Tag.Get("filterType"))) switch fromTag { - case DataTypeText, DataTypeBool, DataTypeFloat, DataTypeInt, DataTypeUint, DataTypeTime, - DataTypeTextArray, DataTypeBoolArray, DataTypeFloatArray, DataTypeIntArray, DataTypeUintArray, DataTypeTimeArray, + case DataTypeText, DataTypeTextArray, + DataTypeBool, DataTypeBoolArray, + DataTypeFloat32, DataTypeFloat32Array, + DataTypeFloat64, DataTypeFloat64Array, + DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64, + DataTypeInt8Array, DataTypeInt16Array, DataTypeInt32Array, DataTypeInt64Array, + DataTypeUint8, DataTypeUint16, DataTypeUint32, DataTypeUint64, + DataTypeUint8Array, DataTypeUint16Array, DataTypeUint32Array, DataTypeUint64Array, + DataTypeTime, DataTypeTimeArray, DataTypeUnsupported: return fromTag case "": - switch field.DataType { + switch field.GORMDataType { case schema.String: return DataTypeText case schema.Bool: return DataTypeBool case schema.Float: - return DataTypeFloat + switch field.Size { + case 32: + return DataTypeFloat32 + case 64: + return DataTypeFloat64 + } + return DataTypeFloat64 case schema.Int: - return DataTypeInt + switch field.Size { + case 8: + return DataTypeInt8 + case 16: + return DataTypeInt16 + case 32: + return DataTypeInt32 + case 64: + return DataTypeInt64 + } + return DataTypeInt64 case schema.Uint: - return DataTypeUint + switch field.Size { + case 8: + return DataTypeUint8 + case 16: + return DataTypeUint16 + case 32: + return DataTypeUint32 + case 64: + return DataTypeUint64 + } + return DataTypeUint64 case schema.Time: return DataTypeTime } @@ -122,24 +174,26 @@ func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { return false, true } return nil, false - case DataTypeFloat, DataTypeFloatArray: - i, err := strconv.ParseFloat(arg, 64) - if err != nil { - return nil, false - } - return i, true - case DataTypeInt, DataTypeIntArray: - i, err := strconv.ParseInt(arg, 10, 64) // TODO check it works on smallint - if err != nil { - return nil, false - } - return i, true - case DataTypeUint, DataTypeUintArray: - i, err := strconv.ParseUint(arg, 10, 64) - if err != nil { - return nil, false - } - return i, true + case DataTypeFloat32, DataTypeFloat32Array: + return validateFloat(arg, 32) + case DataTypeFloat64, DataTypeFloat64Array: + return validateFloat(arg, 64) + case DataTypeInt8, DataTypeInt8Array: + return validateInt(arg, 8) + case DataTypeInt16, DataTypeInt16Array: + return validateInt(arg, 16) + case DataTypeInt32, DataTypeInt32Array: + return validateInt(arg, 32) + case DataTypeInt64, DataTypeInt64Array: + return validateInt(arg, 64) + case DataTypeUint8, DataTypeUint8Array: + return validateUint(arg, 8) + case DataTypeUint16, DataTypeUint16Array: + return validateUint(arg, 16) + case DataTypeUint32, DataTypeUint32Array: + return validateUint(arg, 32) + case DataTypeUint64, DataTypeUint64Array: + return validateUint(arg, 64) case DataTypeTime, DataTypeTimeArray: if validateTime(arg) { return arg, true @@ -148,6 +202,30 @@ func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { return nil, false } +func validateInt(arg string, bitSize int) (int64, bool) { + i, err := strconv.ParseInt(arg, 10, bitSize) + if err != nil { + return 0, false + } + return i, true +} + +func validateUint(arg string, bitSize int) (uint64, bool) { + i, err := strconv.ParseUint(arg, 10, bitSize) + if err != nil { + return 0, false + } + return i, true +} + +func validateFloat(arg string, bitSize int) (float64, bool) { + i, err := strconv.ParseFloat(arg, bitSize) + if err != nil { + return 0, false + } + return i, true +} + func validateTime(timeStr string) bool { for _, format := range []string{time.RFC3339, time.RFC3339Nano, "2006-01-02 15:04:05", "2006-01-02"} { _, err := time.Parse(format, timeStr) diff --git a/util_test.go b/util_test.go index 758602a..f19c322 100644 --- a/util_test.go +++ b/util_test.go @@ -3,6 +3,7 @@ package filter import ( "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "gorm.io/gorm/schema" @@ -89,35 +90,114 @@ func TestConvertToSafeType(t *testing.T) { {value: "no", dataType: DataTypeBoolArray, want: false, wantOk: true}, {value: "not a bool", dataType: DataTypeBoolArray, want: nil, wantOk: false}, - // Float - {value: "1", dataType: DataTypeFloat, want: 1.0, wantOk: true}, - {value: "1.0", dataType: DataTypeFloat, want: 1.0, wantOk: true}, - {value: "1.23", dataType: DataTypeFloat, want: 1.23, wantOk: true}, - {value: "string", dataType: DataTypeFloat, want: nil, wantOk: false}, - {value: "1", dataType: DataTypeFloatArray, want: 1.0, wantOk: true}, - {value: "1.0", dataType: DataTypeFloatArray, want: 1.0, wantOk: true}, - {value: "1.23", dataType: DataTypeFloatArray, want: 1.23, wantOk: true}, - {value: "string", dataType: DataTypeFloatArray, want: nil, wantOk: false}, + // Float32 + {value: "1", dataType: DataTypeFloat32, want: 1.0, wantOk: true}, + {value: "1.0", dataType: DataTypeFloat32, want: 1.0, wantOk: true}, + {value: "1.23", dataType: DataTypeFloat32, want: 1.2300000190734863, wantOk: true}, // Precision loss + {value: "string", dataType: DataTypeFloat32, want: 0.0, wantOk: false}, + {value: "1", dataType: DataTypeFloat32Array, want: 1.0, wantOk: true}, + {value: "1.0", dataType: DataTypeFloat32Array, want: 1.0, wantOk: true}, + {value: "1.23", dataType: DataTypeFloat32Array, want: 1.2300000190734863, wantOk: true}, // Precision loss + {value: "string", dataType: DataTypeFloat32Array, want: 0.0, wantOk: false}, - // Int - {value: "1", dataType: DataTypeInt, want: int64(1), wantOk: true}, - {value: "-2", dataType: DataTypeInt, want: int64(-2), wantOk: true}, - {value: "1.23", dataType: DataTypeInt, want: nil, wantOk: false}, - {value: "string", dataType: DataTypeInt, want: nil, wantOk: false}, - {value: "1", dataType: DataTypeIntArray, want: int64(1), wantOk: true}, - {value: "-2", dataType: DataTypeIntArray, want: int64(-2), wantOk: true}, - {value: "1.23", dataType: DataTypeIntArray, want: nil, wantOk: false}, - {value: "string", dataType: DataTypeIntArray, want: nil, wantOk: false}, + // Float64 + {value: "1", dataType: DataTypeFloat64, want: 1.0, wantOk: true}, + {value: "1.0", dataType: DataTypeFloat64, want: 1.0, wantOk: true}, + {value: "1.23", dataType: DataTypeFloat64, want: 1.23, wantOk: true}, + {value: "string", dataType: DataTypeFloat64, want: 0.0, wantOk: false}, + {value: "1", dataType: DataTypeFloat64Array, want: 1.0, wantOk: true}, + {value: "1.0", dataType: DataTypeFloat64Array, want: 1.0, wantOk: true}, + {value: "1.23", dataType: DataTypeFloat64Array, want: 1.23, wantOk: true}, + {value: "string", dataType: DataTypeFloat64Array, want: 0.0, wantOk: false}, - // Uint - {value: "1", dataType: DataTypeUint, want: uint64(1), wantOk: true}, - {value: "-2", dataType: DataTypeUint, want: nil, wantOk: false}, - {value: "1.23", dataType: DataTypeUint, want: nil, wantOk: false}, - {value: "string", dataType: DataTypeUint, want: nil, wantOk: false}, - {value: "1", dataType: DataTypeUintArray, want: uint64(1), wantOk: true}, - {value: "-2", dataType: DataTypeUintArray, want: nil, wantOk: false}, - {value: "1.23", dataType: DataTypeUintArray, want: nil, wantOk: false}, - {value: "string", dataType: DataTypeUintArray, want: nil, wantOk: false}, + // Int8 + {value: "1", dataType: DataTypeInt8, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt8, want: int64(-2), wantOk: true}, + {value: "128", dataType: DataTypeInt8, want: int64(0), wantOk: false}, + {value: "-129", dataType: DataTypeInt8, want: int64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeInt8, want: int64(0), wantOk: false}, + {value: "string", dataType: DataTypeInt8, want: int64(0), wantOk: false}, + {value: "1", dataType: DataTypeInt8Array, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt8Array, want: int64(-2), wantOk: true}, + {value: "1.23", dataType: DataTypeInt8Array, want: int64(0), wantOk: false}, + {value: "string", dataType: DataTypeInt8Array, want: int64(0), wantOk: false}, + + // Int16 + {value: "1", dataType: DataTypeInt16, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt16, want: int64(-2), wantOk: true}, + {value: "32768", dataType: DataTypeInt16, want: int64(0), wantOk: false}, + {value: "-32769", dataType: DataTypeInt16, want: int64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeInt16, want: int64(0), wantOk: false}, + {value: "string", dataType: DataTypeInt16, want: int64(0), wantOk: false}, + {value: "1", dataType: DataTypeInt16Array, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt16Array, want: int64(-2), wantOk: true}, + {value: "1.23", dataType: DataTypeInt16Array, want: int64(0), wantOk: false}, + {value: "string", dataType: DataTypeInt16Array, want: int64(0), wantOk: false}, + + // Int32 + {value: "1", dataType: DataTypeInt32, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt32, want: int64(-2), wantOk: true}, + {value: "2147483648", dataType: DataTypeInt32, want: int64(0), wantOk: false}, + {value: "-2147483649", dataType: DataTypeInt32, want: int64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeInt32, want: int64(0), wantOk: false}, + {value: "string", dataType: DataTypeInt32, want: int64(0), wantOk: false}, + {value: "1", dataType: DataTypeInt32Array, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt32Array, want: int64(-2), wantOk: true}, + {value: "1.23", dataType: DataTypeInt32Array, want: int64(0), wantOk: false}, + {value: "string", dataType: DataTypeInt32Array, want: int64(0), wantOk: false}, + + // Int64 + {value: "1", dataType: DataTypeInt64, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt64, want: int64(-2), wantOk: true}, + {value: "1.23", dataType: DataTypeInt64, want: int64(0), wantOk: false}, + {value: "string", dataType: DataTypeInt64, want: int64(0), wantOk: false}, + {value: "1", dataType: DataTypeInt64Array, want: int64(1), wantOk: true}, + {value: "-2", dataType: DataTypeInt64Array, want: int64(-2), wantOk: true}, + {value: "1.23", dataType: DataTypeInt64Array, want: int64(0), wantOk: false}, + {value: "string", dataType: DataTypeInt64Array, want: int64(0), wantOk: false}, + + // Uint8 + {value: "1", dataType: DataTypeUint8, want: uint64(1), wantOk: true}, + {value: "256", dataType: DataTypeUint8, want: uint64(0), wantOk: false}, + {value: "-2", dataType: DataTypeUint8, want: uint64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeUint8, want: uint64(0), wantOk: false}, + {value: "string", dataType: DataTypeUint8, want: uint64(0), wantOk: false}, + {value: "1", dataType: DataTypeUint8Array, want: uint64(1), wantOk: true}, + {value: "-2", dataType: DataTypeUint8Array, want: uint64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeUint8Array, want: uint64(0), wantOk: false}, + {value: "string", dataType: DataTypeUint8Array, want: uint64(0), wantOk: false}, + + // Uint16 + {value: "1", dataType: DataTypeUint16, want: uint64(1), wantOk: true}, + {value: "65536", dataType: DataTypeUint16, want: uint64(0), wantOk: false}, + {value: "-2", dataType: DataTypeUint16, want: uint64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeUint16, want: uint64(0), wantOk: false}, + {value: "string", dataType: DataTypeUint16, want: uint64(0), wantOk: false}, + {value: "1", dataType: DataTypeUint16Array, want: uint64(1), wantOk: true}, + {value: "-2", dataType: DataTypeUint16Array, want: uint64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeUint16Array, want: uint64(0), wantOk: false}, + {value: "string", dataType: DataTypeUint16Array, want: uint64(0), wantOk: false}, + + // Uint32 + {value: "1", dataType: DataTypeUint32, want: uint64(1), wantOk: true}, + {value: "4294967296", dataType: DataTypeUint32, want: uint64(0), wantOk: false}, + {value: "-2", dataType: DataTypeUint32, want: uint64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeUint32, want: uint64(0), wantOk: false}, + {value: "string", dataType: DataTypeUint32, want: uint64(0), wantOk: false}, + {value: "1", dataType: DataTypeUint32Array, want: uint64(1), wantOk: true}, + {value: "-2", dataType: DataTypeUint32Array, want: uint64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeUint32Array, want: uint64(0), wantOk: false}, + {value: "string", dataType: DataTypeUint32Array, want: uint64(0), wantOk: false}, + + // Uint64 + {value: "1", dataType: DataTypeUint64, want: uint64(1), wantOk: true}, + {value: "-2", dataType: DataTypeUint64, want: uint64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeUint64, want: uint64(0), wantOk: false}, + {value: "string", dataType: DataTypeUint64, want: uint64(0), wantOk: false}, + {value: "1", dataType: DataTypeUint64Array, want: uint64(1), wantOk: true}, + {value: "-2", dataType: DataTypeUint64Array, want: uint64(0), wantOk: false}, + {value: "1.23", dataType: DataTypeUint64Array, want: uint64(0), wantOk: false}, + {value: "string", dataType: DataTypeUint64Array, want: uint64(0), wantOk: false}, // Time {value: "2023-03-23", dataType: DataTypeTime, want: "2023-03-23", wantOk: true}, @@ -160,8 +240,8 @@ func TestConvertArgsToSafeType(t *testing.T) { wantOk bool }{ {value: []string{"a", "b"}, dataType: DataTypeText, want: []interface{}{"a", "b"}, wantOk: true}, - {value: []string{"3", "4"}, dataType: DataTypeInt, want: []interface{}{int64(3), int64(4)}, wantOk: true}, - {value: []string{"a", "2"}, dataType: DataTypeInt, want: []interface{}(nil), wantOk: false}, + {value: []string{"3", "4"}, dataType: DataTypeInt64, want: []interface{}{int64(3), int64(4)}, wantOk: true}, + {value: []string{"a", "2"}, dataType: DataTypeInt64, want: []interface{}(nil), wantOk: false}, } for _, c := range cases { @@ -180,7 +260,6 @@ func TestGetDataType(t *testing.T) { model interface{} want DataType }{ - {desc: "no specified type", model: struct{ Field string }{}, want: DataTypeText}, {desc: "gorm type string", model: struct { Field string `gorm:"type:string"` }{}, want: DataTypeText}, @@ -189,13 +268,13 @@ func TestGetDataType(t *testing.T) { }{}, want: DataTypeBool}, {desc: "gorm type int", model: struct { Field string `gorm:"type:int"` - }{}, want: DataTypeInt}, + }{}, want: DataTypeInt64}, {desc: "gorm type uint", model: struct { Field string `gorm:"type:uint"` - }{}, want: DataTypeUint}, + }{}, want: DataTypeUint64}, {desc: "gorm type float", model: struct { Field string `gorm:"type:float"` - }{}, want: DataTypeFloat}, + }{}, want: DataTypeFloat64}, {desc: "gorm type time", model: struct { Field string `gorm:"type:time"` }{}, want: DataTypeTime}, @@ -206,6 +285,46 @@ func TestGetDataType(t *testing.T) { Field string `gorm:"type:CHARACTER VARYING(255)"` }{}, want: DataTypeUnsupported}, + {desc: "gorm auto type string", model: struct { + Field string + }{}, want: DataTypeText}, + {desc: "gorm auto type bool", model: struct { + Field bool + }{}, want: DataTypeBool}, + {desc: "gorm auto type int8", model: struct { + Field int8 + }{}, want: DataTypeInt8}, + {desc: "gorm auto type int16", model: struct { + Field int16 + }{}, want: DataTypeInt16}, + {desc: "gorm auto type int32", model: struct { + Field int32 + }{}, want: DataTypeInt32}, + {desc: "gorm auto type int64", model: struct { + Field int64 + }{}, want: DataTypeInt64}, + {desc: "gorm auto type uint8", model: struct { + Field uint8 + }{}, want: DataTypeUint8}, + {desc: "gorm auto type uint16", model: struct { + Field uint16 + }{}, want: DataTypeUint16}, + {desc: "gorm auto type uint32", model: struct { + Field uint32 + }{}, want: DataTypeUint32}, + {desc: "gorm auto type uint64", model: struct { + Field uint64 + }{}, want: DataTypeUint64}, + {desc: "gorm auto type float32", model: struct { + Field float32 + }{}, want: DataTypeFloat32}, + {desc: "gorm auto type float64", model: struct { + Field float64 + }{}, want: DataTypeFloat64}, + {desc: "gorm auto type time", model: struct { + Field time.Time + }{}, want: DataTypeTime}, + {desc: "filter type unsupported", model: struct { Field string `filterType:"-"` }{}, want: DataTypeUnsupported}, @@ -225,23 +344,23 @@ func TestGetDataType(t *testing.T) { Field string `filterType:"bool[]"` }{}, want: DataTypeBoolArray}, {desc: "filter type float", model: struct { - Field string `filterType:"float"` - }{}, want: DataTypeFloat}, + Field string `filterType:"float64"` + }{}, want: DataTypeFloat64}, {desc: "filter type float array", model: struct { - Field string `filterType:"float[]"` - }{}, want: DataTypeFloatArray}, + Field string `filterType:"float64[]"` + }{}, want: DataTypeFloat64Array}, {desc: "filter type int", model: struct { - Field string `filterType:"int"` - }{}, want: DataTypeInt}, + Field string `filterType:"int64"` + }{}, want: DataTypeInt64}, {desc: "filter type int array", model: struct { - Field string `filterType:"int[]"` - }{}, want: DataTypeIntArray}, + Field string `filterType:"int64[]"` + }{}, want: DataTypeInt64Array}, {desc: "filter type uint", model: struct { - Field string `filterType:"uint"` - }{}, want: DataTypeUint}, + Field string `filterType:"uint64"` + }{}, want: DataTypeUint64}, {desc: "filter type uint array", model: struct { - Field string `filterType:"uint[]"` - }{}, want: DataTypeUintArray}, + Field string `filterType:"uint64[]"` + }{}, want: DataTypeUint64Array}, {desc: "filter type time", model: struct { Field string `filterType:"time"` }{}, want: DataTypeTime}, From 0ac0b7bd2af43188bb24a6e8d0ae172a1ab3407a Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Wed, 5 Apr 2023 16:04:06 +0200 Subject: [PATCH 10/12] Fix type-safety tests --- filter_test.go | 6 +++--- search_test.go | 18 +++++++++--------- util.go | 3 --- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/filter_test.go b/filter_test.go index 079f553..60b2bd5 100644 --- a/filter_test.go +++ b/filter_test.go @@ -53,7 +53,7 @@ func TestFilterScope(t *testing.T) { schema := &schema.Schema{ DBNames: []string{"name"}, FieldsByDBName: map[string]*schema.Field{ - "name": {Name: "Name", DBName: "name", DataType: schema.String}, + "name": {Name: "Name", DBName: "name", GORMDataType: schema.String}, }, Table: "test_scope_models", } @@ -84,7 +84,7 @@ func TestFilterScopeBlacklisted(t *testing.T) { schema := &schema.Schema{ DBNames: []string{"name"}, FieldsByDBName: map[string]*schema.Field{ - "name": {Name: "Name"}, + "name": {Name: "Name", GORMDataType: schema.String}, }, } @@ -627,7 +627,7 @@ func TestFilterScopeWithUnsupportedDataType(t *testing.T) { schema := &schema.Schema{ DBNames: []string{"name"}, FieldsByDBName: map[string]*schema.Field{ - "name": {Name: "Name", DBName: "name", DataType: "CHARACTER VARYING(255)"}, + "name": {Name: "Name", DBName: "name", GORMDataType: "custom", DataType: "CHARACTER VARYING(255)"}, }, Table: "test_scope_models", } diff --git a/search_test.go b/search_test.go index 9b7b1b6..e40870c 100644 --- a/search_test.go +++ b/search_test.go @@ -25,9 +25,9 @@ func TestSearchScope(t *testing.T) { schema := &schema.Schema{ FieldsByDBName: map[string]*schema.Field{ - "name": {Name: "Name", DBName: "name", DataType: schema.String}, - "email": {Name: "Email", DBName: "email", DataType: schema.String}, - "role": {Name: "Role", DBName: "role", DataType: schema.String}, + "name": {Name: "Name", DBName: "name", GORMDataType: schema.String}, + "email": {Name: "Email", DBName: "email", GORMDataType: schema.String}, + "role": {Name: "Role", DBName: "role", GORMDataType: schema.String}, }, Table: "test_models", } @@ -88,9 +88,9 @@ func TestSearchScopeEmptyField(t *testing.T) { } schema := &schema.Schema{ FieldsByDBName: map[string]*schema.Field{ - "name": {Name: "Name"}, - "email": {Name: "Email"}, - "role": {Name: "Role"}, + "name": {Name: "Name", GORMDataType: schema.String}, + "email": {Name: "Email", GORMDataType: schema.String}, + "role": {Name: "Role", GORMDataType: schema.String}, }, Table: "test_models", } @@ -454,9 +454,9 @@ func TestSearchScopeWithUnsupportedDataType(t *testing.T) { schema := &schema.Schema{ FieldsByDBName: map[string]*schema.Field{ - "name": {Name: "Name", DBName: "name", DataType: schema.String}, - "email": {Name: "Email", DBName: "email", DataType: "CHARACTER VARYING(255)"}, - "role": {Name: "Role", DBName: "role", DataType: schema.String}, + "name": {Name: "Name", DBName: "name", GORMDataType: schema.String}, + "email": {Name: "Email", DBName: "email", GORMDataType: "custom", DataType: "CHARACTER VARYING(255)"}, + "role": {Name: "Role", DBName: "role", GORMDataType: schema.String}, }, Table: "test_models", } diff --git a/util.go b/util.go index 74cbcbc..50496b8 100644 --- a/util.go +++ b/util.go @@ -127,7 +127,6 @@ func getDataType(field *schema.Field) DataType { case 64: return DataTypeFloat64 } - return DataTypeFloat64 case schema.Int: switch field.Size { case 8: @@ -139,7 +138,6 @@ func getDataType(field *schema.Field) DataType { case 64: return DataTypeInt64 } - return DataTypeInt64 case schema.Uint: switch field.Size { case 8: @@ -151,7 +149,6 @@ func getDataType(field *schema.Field) DataType { case 64: return DataTypeUint64 } - return DataTypeUint64 case schema.Time: return DataTypeTime } From 6b30f41d643de425ba16f7f975d3ae76ea4b012b Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Thu, 6 Apr 2023 11:49:36 +0200 Subject: [PATCH 11/12] Type-safety: add DataTypeEnum --- README.md | 11 ++- operator.go | 31 ++++--- operator_test.go | 221 +++++++++++++++++++++++++++++++++++++++++++++++ util.go | 6 +- util_test.go | 10 +++ 5 files changed, 264 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 9d418ed..2547d16 100644 --- a/README.md +++ b/README.md @@ -266,12 +266,13 @@ It is important to make sure your JSON expression returns a value that has a typ - Don't use `gorm.Model` and add the necessary fields manually. You get better control over json struct tags this way. - Use pointers for nullable relations and nullable fields that implement `sql.Scanner` (such as `null.Time`). -### Filter type +### Type-safety -For non-native types that don't implement the `driver.Valuer` interface, you should always use the `filterType` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. It is also recommended to always add this tag when working with arrays. +For non-native types that don't implement the `driver.Valuer` interface, you should always use the `filterType` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. It is also recommended to always add this tag when working with arrays. This tag is effective for the filter and search features. Available broad types are: - `text` / `text[]` +- `enum` / `enum[]`: use this with custom enum types to prevent "invalid input value" or "invalid operator" errors - `bool` / `bool[]` - `int8` / `int8[]`, `int16` / `int16[]`, `int32` / `int32[]`, `int64` / `int64[]` - `uint` / `uint[]`, `uint16` / `uint16[]`, `uint32` / `uint32[]`, `uint64` / `uint64[]` @@ -365,10 +366,14 @@ func init() { if !dataType.IsArray() { return tx.Where("FALSE") } + + if dataType == filter.DataTypeEnumArray { + column = fmt.Sprintf("CAST(%s as TEXT[])", column) + } query := fmt.Sprintf("%s @> ?", column) switch dataType { - case filter.DataTypeTextArray, filter.DataTypeTimeArray: + case filter.DataTypeTextArray, filter.DataTypeEnumArray, filter.DataTypeTimeArray: return bindArrayArg[string](tx, query, f, dataType) case filter.DataTypeFloat32Array, filter.DataTypeFloat64Array: return bindArrayArg[float64](tx, query, f, dataType) diff --git a/operator.go b/operator.go index c4ea834..d00e517 100644 --- a/operator.go +++ b/operator.go @@ -33,10 +33,10 @@ var ( "$lte": {Function: basicComparison("<="), RequiredArguments: 1}, "$starts": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { - if dataType != DataTypeText { + if dataType != DataTypeText && dataType != DataTypeEnum { return tx.Where("FALSE") } - query := column + " LIKE ?" + query := castEnumAsText(column, dataType) + " LIKE ?" value := sqlutil.EscapeLike(filter.Args[0]) + "%" return filter.Where(tx, query, value) }, @@ -44,10 +44,10 @@ var ( }, "$ends": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { - if dataType != DataTypeText { + if dataType != DataTypeText && dataType != DataTypeEnum { return tx.Where("FALSE") } - query := column + " LIKE ?" + query := castEnumAsText(column, dataType) + " LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) return filter.Where(tx, query, value) }, @@ -55,10 +55,10 @@ var ( }, "$cont": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { - if dataType != DataTypeText { + if dataType != DataTypeText && dataType != DataTypeEnum { return tx.Where("FALSE") } - query := column + " LIKE ?" + query := castEnumAsText(column, dataType) + " LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%" return filter.Where(tx, query, value) }, @@ -66,10 +66,10 @@ var ( }, "$excl": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { - if dataType != DataTypeText { + if dataType != DataTypeText && dataType != DataTypeEnum { return tx.Where("FALSE") } - query := column + " NOT LIKE ?" + query := castEnumAsText(column, dataType) + " NOT LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%" return filter.Where(tx, query, value) }, @@ -116,7 +116,7 @@ var ( if !ok { return tx.Where("FALSE") } - query := column + " BETWEEN ? AND ?" + query := castEnumAsText(column, dataType) + " BETWEEN ? AND ?" return filter.Where(tx, query, args...) }, RequiredArguments: 2, @@ -124,6 +124,13 @@ var ( } ) +func castEnumAsText(column string, dataType DataType) string { + if dataType == DataTypeEnum || dataType == DataTypeEnumArray { + return fmt.Sprintf("CAST(%s AS TEXT)", column) + } + return column +} + func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType.IsArray() { @@ -133,7 +140,8 @@ func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, if !ok { return tx.Where("FALSE") } - query := fmt.Sprintf("%s %s ?", column, op) + + query := fmt.Sprintf("%s %s ?", castEnumAsText(column, dataType), op) return filter.Where(tx, query, arg) } } @@ -147,7 +155,8 @@ func multiComparison(op string) func(tx *gorm.DB, filter *Filter, column string, if !ok { return tx.Where("FALSE") } - query := fmt.Sprintf("%s %s ?", column, op) + + query := fmt.Sprintf("%s %s ?", castEnumAsText(column, dataType), op) return filter.Where(tx, query, args) } } diff --git a/operator_test.go b/operator_test.go index 0d9ecca..4e7725b 100644 --- a/operator_test.go +++ b/operator_test.go @@ -35,6 +35,23 @@ func TestEquals(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$eq", + filter: &Filter{Field: "name", Args: []string{"test"}}, + column: "`test_models`.`name`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`name` AS TEXT) = ?", Vars: []interface{}{"test"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$eq", @@ -99,6 +116,23 @@ func TestNotEquals(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$ne", + filter: &Filter{Field: "name", Args: []string{"test"}}, + column: "`test_models`.`name`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`name` AS TEXT) <> ?", Vars: []interface{}{"test"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$ne", @@ -163,6 +197,23 @@ func TestGreaterThan(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$gt", + filter: &Filter{Field: "enum_col", Args: []string{"18"}}, + column: "`test_models`.`enum_col`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`enum_col` AS TEXT) > ?", Vars: []interface{}{"18"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$gt", @@ -227,6 +278,23 @@ func TestLowerThan(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$lt", + filter: &Filter{Field: "enum_col", Args: []string{"18"}}, + column: "`test_models`.`enum_col`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`enum_col` AS TEXT) < ?", Vars: []interface{}{"18"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$lt", @@ -291,6 +359,23 @@ func TestGreaterThanEqual(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$gte", + filter: &Filter{Field: "enum_col", Args: []string{"18"}}, + column: "`test_models`.`enum_col`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`enum_col` AS TEXT) >= ?", Vars: []interface{}{"18"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$gte", @@ -355,6 +440,23 @@ func TestLowerThanEqual(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$lte", + filter: &Filter{Field: "enum_col", Args: []string{"18"}}, + column: "`test_models`.`enum_col`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`enum_col` AS TEXT) <= ?", Vars: []interface{}{"18"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$lte", @@ -419,6 +521,23 @@ func TestStarts(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$starts", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`name` AS TEXT) LIKE ?", Vars: []interface{}{"te\\%\\_st%"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$starts", @@ -483,6 +602,23 @@ func TestEnds(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$ends", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`name` AS TEXT) LIKE ?", Vars: []interface{}{"%te\\%\\_st"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$ends", @@ -547,6 +683,23 @@ func TestContains(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$cont", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`name` AS TEXT) LIKE ?", Vars: []interface{}{"%te\\%\\_st%"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$cont", @@ -611,6 +764,23 @@ func TestNotContains(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$excl", + filter: &Filter{Field: "name", Args: []string{"te%_st"}}, + column: "`test_models`.`name`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`name` AS TEXT) NOT LIKE ?", Vars: []interface{}{"%te\\%\\_st%"}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$excl", @@ -675,6 +845,23 @@ func TestIn(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$in", + filter: &Filter{Field: "name", Args: []string{"val1", "val2"}}, + column: "`test_models`.`name`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`name` AS TEXT) IN ?", Vars: []interface{}{[]interface{}{"val1", "val2"}}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$in", @@ -739,6 +926,23 @@ func TestNotIn(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$notin", + filter: &Filter{Field: "name", Args: []string{"val1", "val2"}}, + column: "`test_models`.`name`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`name` AS TEXT) NOT IN ?", Vars: []interface{}{[]interface{}{"val1", "val2"}}}, + }, + }, + }, + }, + }, { desc: "cannot_compare_array", op: "$notin", @@ -863,6 +1067,23 @@ func TestBetween(t *testing.T) { }, }, }, + { + desc: "ok_enum", + op: "$between", + filter: &Filter{Field: "enum_col", Args: []string{"18", "25"}}, + column: "`test_models`.`enum_col`", + dataType: DataTypeEnum, + want: map[string]clause.Clause{ + "WHERE": { + Name: "WHERE", + Expression: clause.Where{ + Exprs: []clause.Expression{ + clause.Expr{SQL: "CAST(`test_models`.`enum_col` AS TEXT) BETWEEN ? AND ?", Vars: []interface{}{"18", "25"}}, + }, + }, + }, + }, + }, { desc: "ok_time", op: "$between", diff --git a/util.go b/util.go index 50496b8..c04a424 100644 --- a/util.go +++ b/util.go @@ -24,6 +24,9 @@ const ( DataTypeText DataType = "text" DataTypeTextArray DataType = "text[]" + DataTypeEnum DataType = "enum" + DataTypeEnumArray DataType = "enum[]" + DataTypeBool DataType = "bool" DataTypeBoolArray DataType = "bool[]" @@ -104,6 +107,7 @@ func getDataType(field *schema.Field) DataType { fromTag := DataType(strings.ToLower(field.Tag.Get("filterType"))) switch fromTag { case DataTypeText, DataTypeTextArray, + DataTypeEnum, DataTypeEnumArray, DataTypeBool, DataTypeBoolArray, DataTypeFloat32, DataTypeFloat32Array, DataTypeFloat64, DataTypeFloat64Array, @@ -161,7 +165,7 @@ func getDataType(field *schema.Field) DataType { // be converted. func ConvertToSafeType(arg string, dataType DataType) (interface{}, bool) { switch dataType { - case DataTypeText, DataTypeTextArray: + case DataTypeText, DataTypeTextArray, DataTypeEnum, DataTypeEnumArray: return arg, true case DataTypeBool, DataTypeBoolArray: switch arg { diff --git a/util_test.go b/util_test.go index f19c322..a71bfd3 100644 --- a/util_test.go +++ b/util_test.go @@ -70,6 +70,10 @@ func TestConvertToSafeType(t *testing.T) { {value: "string", dataType: DataTypeText, want: "string", wantOk: true}, {value: "string", dataType: DataTypeTextArray, want: "string", wantOk: true}, + // Enum + {value: "string", dataType: DataTypeEnum, want: "string", wantOk: true}, + {value: "string", dataType: DataTypeEnumArray, want: "string", wantOk: true}, + // Bool {value: "1", dataType: DataTypeBool, want: true, wantOk: true}, {value: "on", dataType: DataTypeBool, want: true, wantOk: true}, @@ -337,6 +341,12 @@ func TestGetDataType(t *testing.T) { {desc: "filter type text array", model: struct { Field string `filterType:"text[]"` }{}, want: DataTypeTextArray}, + {desc: "filter type enum", model: struct { + Field string `filterType:"enum"` + }{}, want: DataTypeEnum}, + {desc: "filter type enum array", model: struct { + Field string `filterType:"enum[]"` + }{}, want: DataTypeEnumArray}, {desc: "filter type bool", model: struct { Field string `filterType:"bool"` }{}, want: DataTypeBool}, From c104a7bb8ab2a81ab7ae6c0aba923c439df2cda9 Mon Sep 17 00:00:00 2001 From: SystemGlitch Date: Fri, 7 Apr 2023 10:09:25 +0200 Subject: [PATCH 12/12] Fix type miscmatch generate FALSE with OR filters --- operator.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/operator.go b/operator.go index d00e517..7af343d 100644 --- a/operator.go +++ b/operator.go @@ -34,7 +34,7 @@ var ( "$starts": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType != DataTypeText && dataType != DataTypeEnum { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } query := castEnumAsText(column, dataType) + " LIKE ?" value := sqlutil.EscapeLike(filter.Args[0]) + "%" @@ -45,7 +45,7 @@ var ( "$ends": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType != DataTypeText && dataType != DataTypeEnum { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } query := castEnumAsText(column, dataType) + " LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) @@ -56,7 +56,7 @@ var ( "$cont": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType != DataTypeText && dataType != DataTypeEnum { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } query := castEnumAsText(column, dataType) + " LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%" @@ -67,7 +67,7 @@ var ( "$excl": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType != DataTypeText && dataType != DataTypeEnum { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } query := castEnumAsText(column, dataType) + " NOT LIKE ?" value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%" @@ -86,7 +86,7 @@ var ( "$istrue": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType != DataTypeBool { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } return filter.Where(tx, column+" IS TRUE") }, @@ -95,7 +95,7 @@ var ( "$isfalse": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType != DataTypeBool { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } return filter.Where(tx, column+" IS FALSE") }, @@ -110,11 +110,11 @@ var ( "$between": { Function: func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType.IsArray() { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } args, ok := ConvertArgsToSafeType(filter.Args[:2], dataType) if !ok { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } query := castEnumAsText(column, dataType) + " BETWEEN ? AND ?" return filter.Where(tx, query, args...) @@ -134,11 +134,11 @@ func castEnumAsText(column string, dataType DataType) string { func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType.IsArray() { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } arg, ok := ConvertToSafeType(filter.Args[0], dataType) if !ok { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } query := fmt.Sprintf("%s %s ?", castEnumAsText(column, dataType), op) @@ -149,11 +149,11 @@ func basicComparison(op string) func(tx *gorm.DB, filter *Filter, column string, func multiComparison(op string) func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { return func(tx *gorm.DB, filter *Filter, column string, dataType DataType) *gorm.DB { if dataType.IsArray() { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } args, ok := ConvertArgsToSafeType(filter.Args, dataType) if !ok { - return tx.Where("FALSE") + return filter.Where(tx, "FALSE") } query := fmt.Sprintf("%s %s ?", castEnumAsText(column, dataType), op)