From 6b0a22d3a9c41368f4023a3c376f52d15af11d82 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Tue, 13 Feb 2024 23:58:41 +0100 Subject: [PATCH] Refactor sortBy builtin --- builtin/builtin.go | 137 ++++++++++++++-------------------------- builtin/builtin_test.go | 18 +++++- builtin/sort.go | 96 ---------------------------- checker/checker.go | 26 +++++++- checker/checker_test.go | 10 --- compiler/compiler.go | 18 ++++++ parser/parser.go | 1 + vm/opcodes.go | 2 + vm/program.go | 6 ++ vm/runtime/sort.go | 45 +++++++++++++ vm/vm.go | 33 +++++++++- 11 files changed, 193 insertions(+), 199 deletions(-) delete mode 100644 builtin/sort.go create mode 100644 vm/runtime/sort.go diff --git a/builtin/builtin.go b/builtin/builtin.go index f7b7bdeb..70a8c09c 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -87,6 +87,11 @@ var Builtins = []*Function{ Predicate: true, Types: types(new(func([]any, func(any) any) map[any][]any)), }, + { + Name: "sortBy", + Predicate: true, + Types: types(new(func([]any, func(any) bool, string) []any)), + }, { Name: "reduce", Predicate: true, @@ -905,109 +910,65 @@ var Builtins = []*Function{ }, { Name: "sort", - Func: func(args ...any) (any, error) { + Safe: func(args ...any) (any, uint, error) { if len(args) != 1 && len(args) != 2 { - return nil, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args)) + return nil, 0, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args)) } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot sort %s", v.Kind()) - } + var array []any - orderBy := OrderBy{} - if len(args) == 2 { - dir, err := ascOrDesc(args[1]) - if err != nil { - return nil, err + switch in := args[0].(type) { + case []any: + array = make([]any, len(in)) + copy(array, in) + case []int: + array = make([]any, len(in)) + for i, v := range in { + array[i] = v + } + case []float64: + array = make([]any, len(in)) + for i, v := range in { + array[i] = v + } + case []string: + array = make([]any, len(in)) + for i, v := range in { + array[i] = v } - orderBy.Desc = dir } - sortable, err := copyArray(v, orderBy) - if err != nil { - return nil, err - } - sort.Sort(sortable) - return sortable.Array, nil - }, - Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 && len(args) != 2 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1 or 2, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot sort %s", args[0]) - } + var desc bool if len(args) == 2 { - switch kind(args[1]) { - case reflect.String, reflect.Interface: + switch args[1].(string) { + case "asc": + desc = false + case "desc": + desc = true default: - return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1]) + return nil, 0, fmt.Errorf("invalid order %s, expected asc or desc", args[1]) } } - return arrayType, nil - }, - }, - { - Name: "sortBy", - Func: func(args ...any) (any, error) { - if len(args) != 2 && len(args) != 3 { - return nil, fmt.Errorf("invalid number of arguments (expected 2 or 3, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot sort %s", v.Kind()) - } - - orderBy := OrderBy{} - - field, ok := args[1].(string) - if !ok { - return nil, fmt.Errorf("invalid argument for sort (expected string, got %s)", reflect.TypeOf(args[1])) - } - orderBy.Field = field - - if len(args) == 3 { - dir, err := ascOrDesc(args[2]) - if err != nil { - return nil, err - } - orderBy.Desc = dir - } - - sortable, err := copyArray(v, orderBy) - if err != nil { - return nil, err + sortable := &runtime.Sort{ + Desc: desc, + Array: array, } sort.Sort(sortable) - return sortable.Array, nil - }, - Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 2 && len(args) != 3 { - return anyType, fmt.Errorf("invalid number of arguments (expected 2 or 3, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot sort %s", args[0]) - } - switch kind(args[1]) { - case reflect.String, reflect.Interface: - default: - return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1]) - } - if len(args) == 3 { - switch kind(args[2]) { - case reflect.String, reflect.Interface: - default: - return anyType, fmt.Errorf("invalid argument for sort (expected string, got %s)", args[1]) - } - } - return arrayType, nil + + return sortable.Array, uint(len(array)), nil }, + Types: types( + new(func([]any, string) []any), + new(func([]int, string) []any), + new(func([]float64, string) []any), + new(func([]string, string) []any), + + new(func([]any) []any), + new(func([]float64) []any), + new(func([]string) []any), + new(func([]int) []any), + ), }, bitFunc("bitand", func(x, y int) (any, error) { return x & y, nil diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index d6b967d1..87e2a7c9 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -530,8 +530,8 @@ func TestBuiltin_sort(t *testing.T) { {`sort(ArrayOfInt)`, []any{1, 2, 3}}, {`sort(ArrayOfFloat)`, []any{1.0, 2.0, 3.0}}, {`sort(ArrayOfInt, 'desc')`, []any{3, 2, 1}}, - {`sortBy(ArrayOfFoo, 'Value')`, []any{mock.Foo{Value: "a"}, mock.Foo{Value: "b"}, mock.Foo{Value: "c"}}}, - {`sortBy([{id: "a"}, {id: "b"}], "id", "desc")`, []any{map[string]any{"id": "b"}, map[string]any{"id": "a"}}}, + {`sortBy(ArrayOfFoo, .Value)`, []any{mock.Foo{Value: "a"}, mock.Foo{Value: "b"}, mock.Foo{Value: "c"}}}, + {`sortBy([{id: "a"}, {id: "b"}], .id, "desc")`, []any{map[string]any{"id": "b"}, map[string]any{"id": "a"}}}, } for _, test := range tests { @@ -546,6 +546,20 @@ func TestBuiltin_sort(t *testing.T) { } } +func TestBuiltin_sort_i64(t *testing.T) { + env := map[string]any{ + "array": []int{1, 2, 3}, + "i64": int64(1), + } + + program, err := expr.Compile(`sort(map(array, i64))`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, []any{int64(1), int64(1), int64(1)}, out) +} + func TestBuiltin_bitOpsFunc(t *testing.T) { tests := []struct { input string diff --git a/builtin/sort.go b/builtin/sort.go deleted file mode 100644 index 9b9ddc16..00000000 --- a/builtin/sort.go +++ /dev/null @@ -1,96 +0,0 @@ -package builtin - -import ( - "fmt" - "reflect" -) - -type Sortable struct { - Array []any - Values []reflect.Value - OrderBy -} - -type OrderBy struct { - Field string - Desc bool -} - -func (s *Sortable) Len() int { - return len(s.Array) -} - -func (s *Sortable) Swap(i, j int) { - s.Array[i], s.Array[j] = s.Array[j], s.Array[i] - s.Values[i], s.Values[j] = s.Values[j], s.Values[i] -} - -func (s *Sortable) Less(i, j int) bool { - a, b := s.Values[i], s.Values[j] - switch a.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if s.Desc { - return a.Int() > b.Int() - } - return a.Int() < b.Int() - case reflect.Float64, reflect.Float32: - if s.Desc { - return a.Float() > b.Float() - } - return a.Float() < b.Float() - case reflect.String: - if s.Desc { - return a.String() > b.String() - } - return a.String() < b.String() - default: - panic(fmt.Sprintf("sort: unsupported type %s", a.Kind())) - } -} - -func copyArray(v reflect.Value, orderBy OrderBy) (*Sortable, error) { - s := &Sortable{ - Array: make([]any, v.Len()), - Values: make([]reflect.Value, v.Len()), - OrderBy: orderBy, - } - var prev reflect.Value - for i := 0; i < s.Len(); i++ { - elem := deref(v.Index(i)) - var value reflect.Value - switch elem.Kind() { - case reflect.Struct: - value = elem.FieldByName(s.Field) - case reflect.Map: - value = elem.MapIndex(reflect.ValueOf(s.Field)) - default: - value = elem - } - value = deref(value) - - s.Array[i] = elem.Interface() - s.Values[i] = value - - if i == 0 { - prev = value - } else if value.Type() != prev.Type() { - return nil, fmt.Errorf("cannot sort array of different types (%s and %s)", value.Type(), prev.Type()) - } - } - return s, nil -} - -func ascOrDesc(arg any) (bool, error) { - dir, ok := arg.(string) - if !ok { - return false, fmt.Errorf("invalid argument for sort (expected string, got %s)", reflect.TypeOf(arg)) - } - switch dir { - case "desc": - return true, nil - case "asc": - return false, nil - default: - return false, fmt.Errorf(`invalid argument for sort (expected "asc" or "desc", got %q)`, dir) - } -} diff --git a/checker/checker.go b/checker/checker.go index 3dc4e95a..11e4eee3 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -633,7 +633,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { if isAny(collection) { return arrayType, info{} } - return reflect.SliceOf(collection.Elem()), info{} + return arrayType, info{} } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -651,7 +651,7 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { closure.NumOut() == 1 && closure.NumIn() == 1 && isAny(closure.In(0)) { - return reflect.SliceOf(closure.Out(0)), info{} + return arrayType, info{} } return v.error(node.Arguments[1], "predicate should has one input and one output param") @@ -739,6 +739,28 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) { } return v.error(node.Arguments[1], "predicate should has one input and one output param") + case "sortBy": + collection, _ := v.visit(node.Arguments[0]) + if !isArray(collection) && !isAny(collection) { + return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection) + } + + v.begin(collection) + closure, _ := v.visit(node.Arguments[1]) + v.end() + + if len(node.Arguments) == 3 { + _, _ = v.visit(node.Arguments[2]) + } + + if isFunc(closure) && + closure.NumOut() == 1 && + closure.NumIn() == 1 && isAny(closure.In(0)) { + + return reflect.TypeOf([]any{}), info{} + } + return v.error(node.Arguments[1], "predicate should has one input and one output param") + case "reduce": collection, _ := v.visit(node.Arguments[0]) if !isArray(collection) && !isAny(collection) { diff --git a/checker/checker_test.go b/checker/checker_test.go index 2bf5ec86..bab9a0a6 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -400,11 +400,6 @@ invalid operation: < (mismatched types mock.Bar and int) (1:29) | all(ArrayOfFoo, {#.Method() < 0}) | ............................^ -map(Any, {0})[0] + "str" -invalid operation: + (mismatched types int and string) (1:18) - | map(Any, {0})[0] + "str" - | .................^ - Variadic() not enough arguments to call Variadic (1:1) | Variadic() @@ -445,11 +440,6 @@ builtin map takes only array (got int) (1:5) | map(1, {2}) | ....^ -map(filter(ArrayOfFoo, {true}), {.Not}) -type mock.Foo has no field Not (1:35) - | map(filter(ArrayOfFoo, {true}), {.Not}) - | ..................................^ - ArrayOfFoo[Foo] array elements can only be selected using an integer (got mock.Foo) (1:12) | ArrayOfFoo[Foo] diff --git a/compiler/compiler.go b/compiler/compiler.go index e95e1db1..a4f189e6 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -883,6 +883,24 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) { c.emit(OpEnd) return + case "sortBy": + c.compile(node.Arguments[0]) + c.emit(OpBegin) + if len(node.Arguments) == 3 { + c.compile(node.Arguments[2]) + } else { + c.emit(OpPush, c.addConstant("asc")) + } + c.emit(OpCreate, 2) + c.emit(OpSetAcc) + c.emitLoop(func() { + c.compile(node.Arguments[1]) + c.emit(OpSortBy) + }) + c.emit(OpSort) + c.emit(OpEnd) + return + case "reduce": c.compile(node.Arguments[0]) c.emit(OpBegin) diff --git a/parser/parser.go b/parser/parser.go index bc620ac6..1eabdebe 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -39,6 +39,7 @@ var predicates = map[string]struct { "findLast": {[]arg{expr, closure}}, "findLastIndex": {[]arg{expr, closure}}, "groupBy": {[]arg{expr, closure}}, + "sortBy": {[]arg{expr, closure, expr | optional}}, "reduce": {[]arg{expr, closure, expr | optional}}, } diff --git a/vm/opcodes.go b/vm/opcodes.go index 1ab4ba79..0417dab6 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -79,6 +79,8 @@ const ( OpThrow OpCreate OpGroupBy + OpSortBy + OpSort OpBegin OpEnd // This opcode must be at the end of this list. ) diff --git a/vm/program.go b/vm/program.go index 55bfeb58..4a878267 100644 --- a/vm/program.go +++ b/vm/program.go @@ -354,6 +354,12 @@ func (program *Program) DisassembleWriter(w io.Writer) { case OpGroupBy: code("OpGroupBy") + case OpSortBy: + code("OpSortBy") + + case OpSort: + code("OpSort") + case OpBegin: code("OpBegin") diff --git a/vm/runtime/sort.go b/vm/runtime/sort.go new file mode 100644 index 00000000..fb1f340d --- /dev/null +++ b/vm/runtime/sort.go @@ -0,0 +1,45 @@ +package runtime + +type SortBy struct { + Desc bool + Array []any + Values []any +} + +func (s *SortBy) Len() int { + return len(s.Array) +} + +func (s *SortBy) Swap(i, j int) { + s.Array[i], s.Array[j] = s.Array[j], s.Array[i] + s.Values[i], s.Values[j] = s.Values[j], s.Values[i] +} + +func (s *SortBy) Less(i, j int) bool { + a, b := s.Values[i], s.Values[j] + if s.Desc { + return Less(b, a) + } + return Less(a, b) +} + +type Sort struct { + Desc bool + Array []any +} + +func (s *Sort) Len() int { + return len(s.Array) +} + +func (s *Sort) Swap(i, j int) { + s.Array[i], s.Array[j] = s.Array[j], s.Array[i] +} + +func (s *Sort) Less(i, j int) bool { + a, b := s.Array[i], s.Array[j] + if s.Desc { + return Less(b, a) + } + return Less(a, b) +} diff --git a/vm/vm.go b/vm/vm.go index 3ae6b5a0..56d5fc2e 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "regexp" + "sort" "strings" "github.com/expr-lang/expr/builtin" @@ -481,8 +482,23 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { case 1: vm.push(make(groupBy)) case 2: + scope := vm.scope() + var desc bool + switch vm.pop().(string) { + case "asc": + desc = false + case "desc": + desc = true + default: + panic("unknown order, use asc or desc") + } + vm.push(&runtime.SortBy{ + Desc: desc, + Array: make([]any, 0, scope.Len), + Values: make([]any, 0, scope.Len), + }) default: - panic("OpCreate: unknown type") + panic(fmt.Sprintf("unknown OpCreate argument %v", arg)) } case OpGroupBy: @@ -491,6 +507,21 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { item := scope.Array.Index(scope.Index).Interface() scope.Acc.(groupBy)[key] = append(scope.Acc.(groupBy)[key], item) + case OpSortBy: + scope := vm.scope() + value := vm.pop() + item := scope.Array.Index(scope.Index).Interface() + sortable := scope.Acc.(*runtime.SortBy) + sortable.Array = append(sortable.Array, item) + sortable.Values = append(sortable.Values, value) + + case OpSort: + scope := vm.scope() + sortable := scope.Acc.(*runtime.SortBy) + sort.Sort(sortable) + vm.memGrow(uint(scope.Len)) + vm.push(sortable.Array) + case OpBegin: a := vm.pop() array := reflect.ValueOf(a)