From f1b68c877d52fe7e5ca985055e113c9a519bb57c Mon Sep 17 00:00:00 2001 From: bizy Date: Wed, 28 Feb 2024 00:30:33 +0700 Subject: [PATCH 01/16] Compare `any` arrays --- expr_test.go | 26 +++ vm/runtime/helpers/main.go | 46 +++++ vm/runtime/helpers[generated].go | 343 +++++++++++++++++++++++++++++++ 3 files changed, 415 insertions(+) diff --git a/expr_test.go b/expr_test.go index 74975362b..ea9213be2 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2482,3 +2482,29 @@ func TestRaceCondition_variables(t *testing.T) { wg.Wait() } + +func TestArrayComparison(t *testing.T) { + tests := []struct { + env any + code string + }{ + {[]string{"A", "B"}, "foo == ['A', 'B']"}, + {[]int{1, 2}, "foo == [1, 2]"}, + {[]uint8{1, 2}, "foo == [1, 2]"}, + {[]float64{1.1, 2.2}, "foo == [1.1, 2.2]"}, + {[]any{"A", 1, 1.1, true}, "foo == ['A', 1, 1.1, true]"}, + {[]string{"A", "B"}, "foo != [1, 2]"}, + } + + for _, tt := range tests { + t.Run(tt.code, func(t *testing.T) { + env := map[string]any{"foo": tt.env} + program, err := expr.Compile(tt.code, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, true, out) + }) + } +} diff --git a/vm/runtime/helpers/main.go b/vm/runtime/helpers/main.go index b3f598a43..54a4fc235 100644 --- a/vm/runtime/helpers/main.go +++ b/vm/runtime/helpers/main.go @@ -19,6 +19,7 @@ func main() { "cases_with_duration": func(op string) string { return cases(op, uints, ints, floats, []string{"time.Duration"}) }, + "array_equal_cases": func() string { return arrayEqualCases([]string{"string"}, uints, ints, floats) }, }). Parse(helpers), ).Execute(&b, nil) @@ -89,6 +90,45 @@ func cases(op string, xs ...[]string) string { return strings.TrimRight(out, "\n") } +func arrayEqualCases(xs ...[]string) string { + var types []string + for _, x := range xs { + types = append(types, x...) + } + + _, _ = fmt.Fprintf(os.Stderr, "Generating array equal cases for %v\n", types) + + var out string + echo := func(s string, xs ...any) { + out += fmt.Sprintf(s, xs...) + "\n" + } + echo(`case []any:`) + echo(`switch y := b.(type) {`) + for _, a := range append(types, "any") { + echo(`case []%v:`, a) + echo(`if len(x) != len(y) { return false }`) + echo(`for i := range x {`) + echo(`if !Equal(x[i], y[i]) { return false }`) + echo(`}`) + echo("return true") + } + echo(`}`) + for _, a := range types { + echo(`case []%v:`, a) + echo(`switch y := b.(type) {`) + echo(`case []any:`) + echo(`return Equal(y, x)`) + echo(`case []%v:`, a) + echo(`if len(x) != len(y) { return false }`) + echo(`for i := range x {`) + echo(`if x[i] != y[i] { return false }`) + echo(`}`) + echo("return true") + echo(`}`) + } + return strings.TrimRight(out, "\n") +} + func isFloat(t string) bool { return strings.HasPrefix(t, "float") } @@ -110,6 +150,7 @@ import ( func Equal(a, b interface{}) bool { switch x := a.(type) { {{ cases "==" }} + {{ array_equal_cases }} case string: switch y := b.(type) { case string: @@ -125,6 +166,11 @@ func Equal(a, b interface{}) bool { case time.Duration: return x == y } + case bool: + switch y := b.(type) { + case bool: + return x == y + } } if IsNil(a) && IsNil(b) { return true diff --git a/vm/runtime/helpers[generated].go b/vm/runtime/helpers[generated].go index 720feb455..d950f1111 100644 --- a/vm/runtime/helpers[generated].go +++ b/vm/runtime/helpers[generated].go @@ -334,6 +334,344 @@ func Equal(a, b interface{}) bool { case float64: return float64(x) == float64(y) } + case []any: + switch y := b.(type) { + case []string: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint8: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint16: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []uint64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int8: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int16: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []int64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []float32: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []float64: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + case []any: + if len(x) != len(y) { + return false + } + for i := range x { + if !Equal(x[i], y[i]) { + return false + } + } + return true + } + case []string: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []string: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint8: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint8: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint16: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint16: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []uint64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []uint64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int8: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int8: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int16: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int16: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []int64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []int64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []float32: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []float32: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } + case []float64: + switch y := b.(type) { + case []any: + return Equal(y, x) + case []float64: + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true + } case string: switch y := b.(type) { case string: @@ -349,6 +687,11 @@ func Equal(a, b interface{}) bool { case time.Duration: return x == y } + case bool: + switch y := b.(type) { + case bool: + return x == y + } } if IsNil(a) && IsNil(b) { return true From ef57900b163f64429fa82a542fe5662e3b41ef1e Mon Sep 17 00:00:00 2001 From: bizy Date: Tue, 5 Mar 2024 01:11:10 +0700 Subject: [PATCH 02/16] Add bench and tests for `runtime.Equal` --- vm/runtime/helpers_test.go | 57 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 vm/runtime/helpers_test.go diff --git a/vm/runtime/helpers_test.go b/vm/runtime/helpers_test.go new file mode 100644 index 000000000..42a0aece0 --- /dev/null +++ b/vm/runtime/helpers_test.go @@ -0,0 +1,57 @@ +package runtime_test + +import ( + "testing" + + "github.com/expr-lang/expr/vm/runtime" + "github.com/stretchr/testify/assert" +) + +var tests = []struct { + name string + a, b any + want bool +}{ + {"int == int", 42, 42, true}, + {"int != int", 42, 33, false}, + {"int == int8", 42, int8(42), true}, + {"int == int16", 42, int16(42), true}, + {"int == int32", 42, int32(42), true}, + {"int == int64", 42, int64(42), true}, + {"float == float", 42.0, 42.0, true}, + {"float != float", 42.0, 33.0, false}, + {"float == int", 42.0, 42, true}, + {"float != int", 42.0, 33, false}, + {"string == string", "foo", "foo", true}, + {"string != string", "foo", "bar", false}, + {"bool == bool", true, true, true}, + {"bool != bool", true, false, false}, + {"[]any == []int", []any{1, 2, 3}, []int{1, 2, 3}, true}, + {"[]any != []int", []any{1, 2, 3}, []int{1, 2, 99}, false}, + {"deep []any == []any", []any{[]int{1}, 2, []any{"3"}}, []any{[]any{1}, 2, []string{"3"}}, true}, + {"deep []any != []any", []any{[]int{1}, 2, []any{"3", "42"}}, []any{[]any{1}, 2, []string{"3"}}, false}, + {"map[string]any == map[string]any", map[string]any{"a": 1}, map[string]any{"a": 1}, true}, + {"map[string]any != map[string]any", map[string]any{"a": 1}, map[string]any{"a": 1, "b": 2}, false}, +} + +func TestEqual(t *testing.T) { + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := runtime.Equal(tt.a, tt.b) + assert.Equal(t, tt.want, got, "Equal(%v, %v) = %v; want %v", tt.a, tt.b, got, tt.want) + got = runtime.Equal(tt.b, tt.a) + assert.Equal(t, tt.want, got, "Equal(%v, %v) = %v; want %v", tt.b, tt.a, got, tt.want) + }) + } + +} + +func BenchmarkEqual(b *testing.B) { + for _, tt := range tests { + b.Run(tt.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + runtime.Equal(tt.a, tt.b) + } + }) + } +} From eb8fb13de1f73785d40d97b1dab8be2f7063b7ff Mon Sep 17 00:00:00 2001 From: Sergey Date: Tue, 27 Feb 2024 20:23:49 +0700 Subject: [PATCH 03/16] Support chzained comparisonc`1 < 2 < 3` (#581) --- expr_test.go | 20 +++++++++++++ parser/operator/operator.go | 4 +++ parser/parser.go | 36 ++++++++++++++++++++++ parser/parser_test.go | 60 +++++++++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+) diff --git a/expr_test.go b/expr_test.go index ea9213be2..a4321c575 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1253,6 +1253,26 @@ func TestExpr(t *testing.T) { `[nil, 3, 4]?.[0]?.[1]`, nil, }, + { + `1 > 2 < 3`, + false, + }, + { + `1 < 2 < 3`, + true, + }, + { + `1 < 2 < 3 > 4`, + false, + }, + { + `1 < 2 < 3 > 2`, + true, + }, + { + `1 < 2 < 3 == true`, + true, + }, } for _, tt := range tests { diff --git a/parser/operator/operator.go b/parser/operator/operator.go index 8d804c7b3..411a0e2bc 100644 --- a/parser/operator/operator.go +++ b/parser/operator/operator.go @@ -54,3 +54,7 @@ var Binary = map[string]Operator{ "^": {100, Right}, "??": {500, Left}, } + +func IsComparison(op string) bool { + return op == "<" || op == ">" || op == ">=" || op == "<=" +} diff --git a/parser/parser.go b/parser/parser.go index 1eabdebe2..9114bc0c9 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -164,6 +164,11 @@ func (p *parser) parseExpression(precedence int) Node { break } + if operator.IsComparison(opToken.Value) { + nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence) + goto next + } + var nodeRight Node if op.Associativity == operator.Left { nodeRight = p.parseExpression(op.Precedence + 1) @@ -685,3 +690,34 @@ func (p *parser) parsePostfixExpression(node Node) Node { } return node } + +func (p *parser) parseComparison(left Node, token Token, precedence int) Node { + var rootNode Node + for { + comparator := p.parseExpression(precedence + 1) + cmpNode := &BinaryNode{ + Operator: token.Value, + Left: left, + Right: comparator, + } + cmpNode.SetLocation(token.Location) + if rootNode == nil { + rootNode = cmpNode + } else { + rootNode = &BinaryNode{ + Operator: "&&", + Left: rootNode, + Right: cmpNode, + } + rootNode.SetLocation(token.Location) + } + + left = comparator + token = p.current + if !(token.Is(Operator) && operator.IsComparison(token.Value) && p.err == nil) { + break + } + p.next() + } + return rootNode +} diff --git a/parser/parser_test.go b/parser/parser_test.go index b633bd52e..9225e1028 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -531,6 +531,66 @@ world`}, To: &IntegerNode{Value: 3}, }, }, + { + `1 < 2 > 3`, + &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: ">", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + }, + { + `1 < 2 < 3 < 4`, + &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 3}, + Right: &IntegerNode{Value: 4}, + }, + }, + }, + { + `1 < 2 < 3 == true`, + &BinaryNode{ + Operator: "==", + Left: &BinaryNode{ + Operator: "&&", + Left: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 2}, + }, + Right: &BinaryNode{ + Operator: "<", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 3}, + }, + }, + Right: &BoolNode{Value: true}, + }, + }, } for _, test := range tests { t.Run(test.input, func(t *testing.T) { From 355fb28bafa40cbf11ea3cee95bb102c1a8ac64d Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Thu, 29 Feb 2024 20:51:34 +0100 Subject: [PATCH 04/16] Add spans --- compiler/compiler.go | 29 +++++++++++++++++++++++++++++ conf/config.go | 1 + vm/opcodes.go | 2 ++ vm/program.go | 9 +++++++++ vm/utils.go | 13 +++++++++++++ vm/vm.go | 9 +++++++++ 6 files changed, 63 insertions(+) diff --git a/compiler/compiler.go b/compiler/compiler.go index a4f189e6b..808b53c9b 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -50,6 +50,11 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro } } + var span *Span + if len(c.spans) > 0 { + span = c.spans[0] + } + program = NewProgram( tree.Source, tree.Node, @@ -60,6 +65,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro c.arguments, c.functions, c.debugInfo, + span, ) return } @@ -76,6 +82,7 @@ type compiler struct { functionsIndex map[string]int debugInfo map[string]string nodes []ast.Node + spans []*Span chains [][]int arguments []int } @@ -193,6 +200,28 @@ func (c *compiler) compile(node ast.Node) { c.nodes = c.nodes[:len(c.nodes)-1] }() + if c.config != nil && c.config.Profile { + span := &Span{ + Name: reflect.TypeOf(node).String(), + Expression: node.String(), + } + if len(c.spans) > 0 { + prev := c.spans[len(c.spans)-1] + prev.Children = append(prev.Children, span) + } + c.spans = append(c.spans, span) + defer func() { + if len(c.spans) > 1 { + c.spans = c.spans[:len(c.spans)-1] + } + }() + + c.emit(OpProfileStart, c.addConstant(span)) + defer func() { + c.emit(OpProfileEnd, c.addConstant(span)) + }() + } + switch n := node.(type) { case *ast.NilNode: c.NilNode(n) diff --git a/conf/config.go b/conf/config.go index e543732ce..799898109 100644 --- a/conf/config.go +++ b/conf/config.go @@ -20,6 +20,7 @@ type Config struct { ExpectAny bool Optimize bool Strict bool + Profile bool ConstFns map[string]reflect.Value Visitors []ast.Visitor Functions FunctionsTable diff --git a/vm/opcodes.go b/vm/opcodes.go index 0417dab61..84d751d6b 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -81,6 +81,8 @@ const ( OpGroupBy OpSortBy OpSort + OpProfileStart + OpProfileEnd OpBegin OpEnd // This opcode must be at the end of this list. ) diff --git a/vm/program.go b/vm/program.go index 4a878267b..989546744 100644 --- a/vm/program.go +++ b/vm/program.go @@ -27,6 +27,7 @@ type Program struct { variables int functions []Function debugInfo map[string]string + span *Span } // NewProgram returns a new Program. It's used by the compiler. @@ -40,6 +41,7 @@ func NewProgram( arguments []int, functions []Function, debugInfo map[string]string, + span *Span, ) *Program { return &Program{ source: source, @@ -51,6 +53,7 @@ func NewProgram( Arguments: arguments, functions: functions, debugInfo: debugInfo, + span: span, } } @@ -360,6 +363,12 @@ func (program *Program) DisassembleWriter(w io.Writer) { case OpSort: code("OpSort") + case OpProfileStart: + code("OpProfileStart") + + case OpProfileEnd: + code("OpProfileEnd") + case OpBegin: code("OpBegin") diff --git a/vm/utils.go b/vm/utils.go index d7db2a52a..fc2f5e7b8 100644 --- a/vm/utils.go +++ b/vm/utils.go @@ -2,6 +2,7 @@ package vm import ( "reflect" + "time" ) type ( @@ -25,3 +26,15 @@ type Scope struct { } type groupBy = map[any][]any + +type Span struct { + Name string `json:"name"` + Expression string `json:"expression"` + Duration int64 `json:"duration"` + Children []*Span `json:"children"` + start time.Time +} + +func GetSpan(program *Program) *Span { + return program.span +} diff --git a/vm/vm.go b/vm/vm.go index 1e85893b0..7e933ce74 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -8,6 +8,7 @@ import ( "regexp" "sort" "strings" + "time" "github.com/expr-lang/expr/builtin" "github.com/expr-lang/expr/file" @@ -523,6 +524,14 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { vm.memGrow(uint(scope.Len)) vm.push(sortable.Array) + case OpProfileStart: + span := program.Constants[arg].(*Span) + span.start = time.Now() + + case OpProfileEnd: + span := program.Constants[arg].(*Span) + span.Duration += time.Since(span.start).Nanoseconds() + case OpBegin: a := vm.pop() array := reflect.ValueOf(a) From ded019d21ec180b7018641732617c59e8932da32 Mon Sep 17 00:00:00 2001 From: Ganesan Karuppasamy Date: Sun, 3 Mar 2024 21:14:32 +0530 Subject: [PATCH 05/16] Enable Support for Arrays in Sum, Mean, and Median Functions (#580) --- builtin/builtin.go | 201 ++++++---------------------------------- builtin/builtin_test.go | 13 +++ builtin/lib.go | 154 ++++++++++++++++++++++++------ builtin/validation.go | 38 ++++++++ 4 files changed, 206 insertions(+), 200 deletions(-) create mode 100644 builtin/validation.go diff --git a/builtin/builtin.go b/builtin/builtin.go index fc48e111a..7bf377df2 100644 --- a/builtin/builtin.go +++ b/builtin/builtin.go @@ -135,42 +135,21 @@ var Builtins = []*Function{ Name: "ceil", Fast: Ceil, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for ceil (type %s)", args[0]) + return validateRoundFunc("ceil", args) }, }, { Name: "floor", Fast: Floor, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0]) + return validateRoundFunc("floor", args) }, }, { Name: "round", Fast: Round, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: - return floatType, nil - } - return anyType, fmt.Errorf("invalid argument for floor (type %s)", args[0]) + return validateRoundFunc("round", args) }, }, { @@ -392,185 +371,63 @@ var Builtins = []*Function{ }, { Name: "max", - Func: Max, + Func: func(args ...any) (any, error) { + return minMax("max", runtime.Less, args...) + }, Validate: func(args []reflect.Type) (reflect.Type, error) { - switch len(args) { - case 0: - return anyType, fmt.Errorf("not enough arguments to call max") - case 1: - if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice { - return anyType, nil - } - fallthrough - default: - for _, arg := range args { - switch kind(arg) { - case reflect.Interface, reflect.Array, reflect.Slice: - return anyType, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - default: - return anyType, fmt.Errorf("invalid argument for max (type %s)", arg) - } - } - return args[0], nil - } + return validateAggregateFunc("max", args) }, }, { Name: "min", - Func: Min, + Func: func(args ...any) (any, error) { + return minMax("min", runtime.More, args...) + }, Validate: func(args []reflect.Type) (reflect.Type, error) { - switch len(args) { - case 0: - return anyType, fmt.Errorf("not enough arguments to call min") - case 1: - if kindName := kind(args[0]); kindName == reflect.Array || kindName == reflect.Slice { - return anyType, nil - } - fallthrough - default: - for _, arg := range args { - switch kind(arg) { - case reflect.Interface, reflect.Array, reflect.Slice: - return anyType, nil - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: - default: - return anyType, fmt.Errorf("invalid argument for min (type %s)", arg) - } - } - return args[0], nil - - } + return validateAggregateFunc("min", args) }, }, { Name: "sum", - Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot sum %s", v.Kind()) - } - sum := int64(0) - i := 0 - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - sum += it.Int() - } else if it.CanFloat() { - goto float - } else { - return nil, fmt.Errorf("cannot sum %s", it.Kind()) - } - } - return int(sum), nil - float: - fSum := float64(sum) - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - fSum += float64(it.Int()) - } else if it.CanFloat() { - fSum += it.Float() - } else { - return nil, fmt.Errorf("cannot sum %s", it.Kind()) - } - } - return fSum, nil - }, + Func: sum, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot sum %s", args[0]) - } - return anyType, nil + return validateAggregateFunc("sum", args) }, }, { Name: "mean", Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot mean %s", v.Kind()) + count, sum, err := mean(args...) + if err != nil { + return nil, err } - if v.Len() == 0 { + if count == 0 { return 0.0, nil } - sum := float64(0) - i := 0 - for ; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - sum += float64(it.Int()) - } else if it.CanFloat() { - sum += it.Float() - } else { - return nil, fmt.Errorf("cannot mean %s", it.Kind()) - } - } - return sum / float64(i), nil + return sum / float64(count), nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot avg %s", args[0]) - } - return floatType, nil + return validateAggregateFunc("mean", args) }, }, { Name: "median", Func: func(args ...any) (any, error) { - if len(args) != 1 { - return nil, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - v := reflect.ValueOf(args[0]) - if v.Kind() != reflect.Slice && v.Kind() != reflect.Array { - return nil, fmt.Errorf("cannot median %s", v.Kind()) - } - if v.Len() == 0 { - return 0.0, nil + values, err := median(args...) + if err != nil { + return nil, err } - s := make([]float64, v.Len()) - for i := 0; i < v.Len(); i++ { - it := deref.Value(v.Index(i)) - if it.CanInt() { - s[i] = float64(it.Int()) - } else if it.CanFloat() { - s[i] = it.Float() - } else { - return nil, fmt.Errorf("cannot median %s", it.Kind()) + if n := len(values); n > 0 { + sort.Float64s(values) + if n%2 == 1 { + return values[n/2], nil } + return (values[n/2-1] + values[n/2]) / 2, nil } - sort.Float64s(s) - if len(s)%2 == 0 { - return (s[len(s)/2-1] + s[len(s)/2]) / 2, nil - } - return s[len(s)/2], nil + return 0.0, nil }, Validate: func(args []reflect.Type) (reflect.Type, error) { - if len(args) != 1 { - return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) - } - switch kind(args[0]) { - case reflect.Interface, reflect.Slice, reflect.Array: - default: - return anyType, fmt.Errorf("cannot median %s", args[0]) - } - return floatType, nil + return validateAggregateFunc("median", args) }, }, { diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index bc1a2e149..aa324c9be 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -85,19 +85,29 @@ func TestBuiltin(t *testing.T) { {`min(1.5, 2.5, 3.5)`, 1.5}, {`min([1, 2, 3])`, 1}, {`min([1.5, 2.5, 3.5])`, 1.5}, + {`min(-1, [1.5, 2.5, 3.5])`, -1}, {`sum(1..9)`, 45}, {`sum([.5, 1.5, 2.5])`, 4.5}, {`sum([])`, 0}, {`sum([1, 2, 3.0, 4])`, 10.0}, + {`sum(10, [1, 2, 3], 1..9)`, 61}, + {`sum(-10, [1, 2, 3, 4])`, 0}, + {`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1}, {`mean(1..9)`, 5.0}, {`mean([.5, 1.5, 2.5])`, 1.5}, {`mean([])`, 0.0}, {`mean([1, 2, 3.0, 4])`, 2.5}, + {`mean(10, [1, 2, 3], 1..9)`, 4.6923076923076925}, + {`mean(-10, [1, 2, 3, 4])`, 0.0}, + {`mean(10.9, 1..9)`, 5.59}, {`median(1..9)`, 5.0}, {`median([.5, 1.5, 2.5])`, 1.5}, {`median([])`, 0.0}, {`median([1, 2, 3])`, 2.0}, {`median([1, 2, 3, 4])`, 2.5}, + {`median(10, [1, 2, 3], 1..9)`, 4.0}, + {`median(-10, [1, 2, 3, 4])`, 2.0}, + {`median(1..5, 4.9)`, 3.5}, {`toJSON({foo: 1, bar: 2})`, "{\n \"bar\": 2,\n \"foo\": 1\n}"}, {`fromJSON("[1, 2, 3]")`, []any{1.0, 2.0, 3.0}}, {`toBase64("hello")`, "aGVsbG8="}, @@ -207,6 +217,9 @@ func TestBuiltin_errors(t *testing.T) { {`min()`, `not enough arguments to call min`}, {`min(1, "2")`, `invalid argument for min (type string)`}, {`min([1, "2"])`, `invalid argument for min (type string)`}, + {`median(1..9, "t")`, "invalid argument for median (type string)"}, + {`mean("s", 1..9)`, "invalid argument for mean (type string)"}, + {`sum("s", "h")`, "invalid argument for sum (type string)"}, {`duration("error")`, `invalid duration`}, {`date("error")`, `invalid date`}, {`get()`, `invalid number of arguments (expected 2, got 0)`}, diff --git a/builtin/lib.go b/builtin/lib.go index b08c2ed2b..9ff9478aa 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -6,7 +6,7 @@ import ( "reflect" "strconv" - "github.com/expr-lang/expr/vm/runtime" + "github.com/expr-lang/expr/internal/deref" ) func Len(x any) any { @@ -254,45 +254,143 @@ func String(arg any) any { return fmt.Sprintf("%v", arg) } -func Max(args ...any) (any, error) { - return minMaxFunc("max", runtime.Less, args) -} +func sum(args ...any) (any, error) { + var total int + var fTotal float64 + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) -func Min(args ...any) (any, error) { - return minMaxFunc("min", runtime.More, args) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemSum, err := sum(rv.Index(i).Interface()) + if err != nil { + return nil, err + } + switch elemSum := elemSum.(type) { + case int: + total += elemSum + case float64: + fTotal += elemSum + } + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + total += int(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + total += int(rv.Uint()) + case reflect.Float32, reflect.Float64: + fTotal += rv.Float() + default: + return nil, fmt.Errorf("invalid argument for sum (type %T)", arg) + } + } + + if fTotal != 0.0 { + return fTotal + float64(total), nil + } + return total, nil } -func minMaxFunc(name string, fn func(any, any) bool, args []any) (any, error) { +func minMax(name string, fn func(any, any) bool, args ...any) (any, error) { var val any for _, arg := range args { - switch v := arg.(type) { - case []float32, []float64, []uint, []uint8, []uint16, []uint32, []uint64, []int, []int8, []int16, []int32, []int64: - rv := reflect.ValueOf(v) - if rv.Len() == 0 { - return nil, fmt.Errorf("not enough arguments to call %s", name) - } - arg = rv.Index(0).Interface() - for i := 1; i < rv.Len(); i++ { - elem := rv.Index(i).Interface() - if fn(arg, elem) { - arg = elem + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemVal, err := minMax(name, fn, rv.Index(i).Interface()) + if err != nil { + return nil, err + } + switch elemVal.(type) { + case int, int8, int16, int32, int64, + uint, uint8, uint16, uint32, uint64, + float32, float64: + if elemVal != nil && (val == nil || fn(val, elemVal)) { + val = elemVal + } + default: + return nil, fmt.Errorf("invalid argument for %s (type %T)", name, elemVal) } + } - case []any: - var err error - if arg, err = minMaxFunc(name, fn, v); err != nil { - return nil, err + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + elemVal := rv.Interface() + if val == nil || fn(val, elemVal) { + val = elemVal } - case float32, float64, uint, uint8, uint16, uint32, uint64, int, int8, int16, int32, int64: default: if len(args) == 1 { - return arg, nil + return args[0], nil } - return nil, fmt.Errorf("invalid argument for %s (type %T)", name, v) - } - if val == nil || fn(val, arg) { - val = arg + return nil, fmt.Errorf("invalid argument for %s (type %T)", name, arg) } } return val, nil } + +func mean(args ...any) (int, float64, error) { + var total float64 + var count int + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elemCount, elemSum, err := mean(rv.Index(i).Interface()) + if err != nil { + return 0, 0, err + } + total += elemSum + count += elemCount + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + total += float64(rv.Int()) + count++ + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + total += float64(rv.Uint()) + count++ + case reflect.Float32, reflect.Float64: + total += rv.Float() + count++ + default: + return 0, 0, fmt.Errorf("invalid argument for mean (type %T)", arg) + } + } + return count, total, nil +} + +func median(args ...any) ([]float64, error) { + var values []float64 + + for _, arg := range args { + rv := reflect.ValueOf(deref.Deref(arg)) + switch rv.Kind() { + case reflect.Array, reflect.Slice: + size := rv.Len() + for i := 0; i < size; i++ { + elems, err := median(rv.Index(i).Interface()) + if err != nil { + return nil, err + } + values = append(values, elems...) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + values = append(values, float64(rv.Int())) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + values = append(values, float64(rv.Uint())) + case reflect.Float32, reflect.Float64: + values = append(values, rv.Float()) + default: + return nil, fmt.Errorf("invalid argument for median (type %T)", arg) + } + } + return values, nil +} diff --git a/builtin/validation.go b/builtin/validation.go new file mode 100644 index 000000000..057f247e9 --- /dev/null +++ b/builtin/validation.go @@ -0,0 +1,38 @@ +package builtin + +import ( + "fmt" + "reflect" + + "github.com/expr-lang/expr/internal/deref" +) + +func validateAggregateFunc(name string, args []reflect.Type) (reflect.Type, error) { + switch len(args) { + case 0: + return anyType, fmt.Errorf("not enough arguments to call %s", name) + default: + for _, arg := range args { + switch kind(deref.Type(arg)) { + case reflect.Interface, reflect.Array, reflect.Slice: + return anyType, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64: + default: + return anyType, fmt.Errorf("invalid argument for %s (type %s)", name, arg) + } + } + return args[0], nil + } +} + +func validateRoundFunc(name string, args []reflect.Type) (reflect.Type, error) { + if len(args) != 1 { + return anyType, fmt.Errorf("invalid number of arguments (expected 1, got %d)", len(args)) + } + switch kind(args[0]) { + case reflect.Float32, reflect.Float64, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Interface: + return floatType, nil + default: + return anyType, fmt.Errorf("invalid argument for %s (type %s)", name, args[0]) + } +} From 0e9136ef319565b80ca63ef4b35780d329d0dfff Mon Sep 17 00:00:00 2001 From: Ganesan Karuppasamy Date: Mon, 4 Mar 2024 19:10:33 +0530 Subject: [PATCH 06/16] Fix `-1 not in []` expressions (#590) --- compiler/compiler_test.go | 33 +++++++++++++ expr_test.go | 8 +++ parser/operator/operator.go | 9 ++++ parser/parser.go | 99 ++++++++++++++++++++----------------- parser/parser_test.go | 61 +++++++++++++++++++++++ 5 files changed, 164 insertions(+), 46 deletions(-) diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 741142a77..fbd83ec86 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -541,6 +541,39 @@ func TestCompile_optimizes_jumps(t *testing.T) { {vm.OpFetch, 0}, }, }, + { + `-1 not in [1, 2, 5]`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, + { + `1 + 8 not in [1, 2, 5]`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, + { + `true ? false : 8 not in [1, 2, 5]`, + []op{ + {vm.OpTrue, 0}, + {vm.OpJumpIfFalse, 3}, + {vm.OpPop, 0}, + {vm.OpFalse, 0}, + {vm.OpJump, 5}, + {vm.OpPop, 0}, + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpIn, 0}, + {vm.OpNot, 0}, + }, + }, } for _, test := range tests { diff --git a/expr_test.go b/expr_test.go index a4321c575..46cb8fe89 100644 --- a/expr_test.go +++ b/expr_test.go @@ -785,6 +785,10 @@ func TestExpr(t *testing.T) { `Two not in 0..1`, true, }, + { + `-1 not in [1]`, + true, + }, { `Int32 in [10, 20]`, false, @@ -797,6 +801,10 @@ func TestExpr(t *testing.T) { `String matches ("^" + String + "$")`, true, }, + { + `'foo' + 'bar' not matches 'foobar'`, + false, + }, { `"foobar" contains "bar"`, true, diff --git a/parser/operator/operator.go b/parser/operator/operator.go index 411a0e2bc..4eeaf80ed 100644 --- a/parser/operator/operator.go +++ b/parser/operator/operator.go @@ -20,6 +20,15 @@ func IsBoolean(op string) bool { return op == "and" || op == "or" || op == "&&" || op == "||" } +func AllowedNegateSuffix(op string) bool { + switch op { + case "contains", "matches", "startsWith", "endsWith", "in": + return true + default: + return false + } +} + var Unary = map[string]Operator{ "not": {50, Left}, "!": {50, Left}, diff --git a/parser/parser.go b/parser/parser.go index 9114bc0c9..9cb79cbbb 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -126,10 +126,8 @@ func (p *parser) expect(kind Kind, values ...string) { // parse functions func (p *parser) parseExpression(precedence int) Node { - if precedence == 0 { - if p.current.Is(Operator, "let") { - return p.parseVariableDeclaration() - } + if precedence == 0 && p.current.Is(Operator, "let") { + return p.parseVariableDeclaration() } nodeLeft := p.parsePrimary() @@ -137,62 +135,71 @@ func (p *parser) parseExpression(precedence int) Node { prevOperator := "" opToken := p.current for opToken.Is(Operator) && p.err == nil { - negate := false + negate := opToken.Is(Operator, "not") var notToken Token // Handle "not *" operator, like "not in" or "not contains". - if opToken.Is(Operator, "not") { + if negate { + currentPos := p.pos p.next() - notToken = p.current - negate = true - opToken = p.current + if operator.AllowedNegateSuffix(p.current.Value) { + if op, ok := operator.Binary[p.current.Value]; ok && op.Precedence >= precedence { + notToken = p.current + opToken = p.current + } else { + p.pos = currentPos + p.current = opToken + break + } + } else { + p.error("unexpected token %v", p.current) + break + } } - if op, ok := operator.Binary[opToken.Value]; ok { - if op.Precedence >= precedence { - p.next() + if op, ok := operator.Binary[opToken.Value]; ok && op.Precedence >= precedence { + p.next() - if opToken.Value == "|" { - identToken := p.current - p.expect(Identifier) - nodeLeft = p.parseCall(identToken, []Node{nodeLeft}, true) - goto next - } + if opToken.Value == "|" { + identToken := p.current + p.expect(Identifier) + nodeLeft = p.parseCall(identToken, []Node{nodeLeft}, true) + goto next + } - if prevOperator == "??" && opToken.Value != "??" && !opToken.Is(Bracket, "(") { - p.errorAt(opToken, "Operator (%v) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses.", opToken.Value) - break - } + if prevOperator == "??" && opToken.Value != "??" && !opToken.Is(Bracket, "(") { + p.errorAt(opToken, "Operator (%v) and coalesce expressions (??) cannot be mixed. Wrap either by parentheses.", opToken.Value) + break + } - if operator.IsComparison(opToken.Value) { - nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence) - goto next - } + if operator.IsComparison(opToken.Value) { + nodeLeft = p.parseComparison(nodeLeft, opToken, op.Precedence) + goto next + } - var nodeRight Node - if op.Associativity == operator.Left { - nodeRight = p.parseExpression(op.Precedence + 1) - } else { - nodeRight = p.parseExpression(op.Precedence) - } + var nodeRight Node + if op.Associativity == operator.Left { + nodeRight = p.parseExpression(op.Precedence + 1) + } else { + nodeRight = p.parseExpression(op.Precedence) + } - nodeLeft = &BinaryNode{ - Operator: opToken.Value, - Left: nodeLeft, - Right: nodeRight, - } - nodeLeft.SetLocation(opToken.Location) + nodeLeft = &BinaryNode{ + Operator: opToken.Value, + Left: nodeLeft, + Right: nodeRight, + } + nodeLeft.SetLocation(opToken.Location) - if negate { - nodeLeft = &UnaryNode{ - Operator: "not", - Node: nodeLeft, - } - nodeLeft.SetLocation(notToken.Location) + if negate { + nodeLeft = &UnaryNode{ + Operator: "not", + Node: nodeLeft, } - - goto next + nodeLeft.SetLocation(notToken.Location) } + + goto next } break diff --git a/parser/parser_test.go b/parser/parser_test.go index 9225e1028..2a30787a0 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -365,6 +365,62 @@ world`}, &UnaryNode{Operator: "not", Node: &IdentifierNode{Value: "in_var"}}, }, + { + "-1 not in [1, 2, 3, 4]", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "in", + Left: &UnaryNode{Operator: "-", Node: &IntegerNode{Value: 1}}, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, + }}}}, + }, + { + "1*8 not in [1, 2, 3, 4]", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "in", + Left: &BinaryNode{Operator: "*", + Left: &IntegerNode{Value: 1}, + Right: &IntegerNode{Value: 8}, + }, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 3}, + &IntegerNode{Value: 4}, + }}}}, + }, + { + "2==2 ? false : 3 not in [1, 2, 5]", + &ConditionalNode{ + Cond: &BinaryNode{ + Operator: "==", + Left: &IntegerNode{Value: 2}, + Right: &IntegerNode{Value: 2}, + }, + Exp1: &BoolNode{Value: false}, + Exp2: &UnaryNode{ + Operator: "not", + Node: &BinaryNode{ + Operator: "in", + Left: &IntegerNode{Value: 3}, + Right: &ArrayNode{Nodes: []Node{ + &IntegerNode{Value: 1}, + &IntegerNode{Value: 2}, + &IntegerNode{Value: 5}, + }}}}}, + }, + { + "'foo' + 'bar' not matches 'foobar'", + &UnaryNode{Operator: "not", + Node: &BinaryNode{Operator: "matches", + Left: &BinaryNode{Operator: "+", + Left: &StringNode{Value: "foo"}, + Right: &StringNode{Value: "bar"}}, + Right: &StringNode{Value: "foobar"}}}, + }, { "all(Tickets, #)", &BuiltinNode{ @@ -706,6 +762,11 @@ invalid float literal: strconv.ParseFloat: parsing "0o1E+1": invalid syntax (1:6 invalid float literal: strconv.ParseFloat: parsing "1E": invalid syntax (1:2) | 1E | .^ + +1 not == [1, 2, 5] +unexpected token Operator("==") (1:7) + | 1 not == [1, 2, 5] + | ......^ ` func TestParse_error(t *testing.T) { From 5708180cb26fa2e6d25be8d1a59e40ab5dc8dd3e Mon Sep 17 00:00:00 2001 From: Sergey Date: Sun, 17 Mar 2024 14:51:53 +0700 Subject: [PATCH 07/16] `expr.Operator` passes before `expr.Env` caused error (#606) --- checker/checker_test.go | 2 +- conf/config.go | 6 +++++- expr_test.go | 14 ++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/checker/checker_test.go b/checker/checker_test.go index d6a84abc5..29c50807e 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -632,7 +632,7 @@ func TestCheck_TaggedFieldName(t *testing.T) { tree, err := parser.Parse(`foo.bar`) require.NoError(t, err) - config := &conf.Config{} + config := conf.CreateNew() expr.Env(struct { x struct { y bool `expr:"bar"` diff --git a/conf/config.go b/conf/config.go index 799898109..01a407a10 100644 --- a/conf/config.go +++ b/conf/config.go @@ -32,6 +32,7 @@ type Config struct { func CreateNew() *Config { c := &Config{ Optimize: true, + Types: make(TypesTable), ConstFns: make(map[string]reflect.Value), Functions: make(map[string]*builtin.Function), Builtins: make(map[string]*builtin.Function), @@ -62,7 +63,10 @@ func (c *Config) WithEnv(env any) { } c.Env = env - c.Types = CreateTypesTable(env) + types := CreateTypesTable(env) + for name, t := range types { + c.Types[name] = t + } c.MapEnv = mapEnv c.DefaultType = mapValueType c.Strict = true diff --git a/expr_test.go b/expr_test.go index 46cb8fe89..790fdd5d9 100644 --- a/expr_test.go +++ b/expr_test.go @@ -2511,6 +2511,20 @@ func TestRaceCondition_variables(t *testing.T) { wg.Wait() } +func TestOperatorDependsOnEnv(t *testing.T) { + env := map[string]any{ + "plus": func(a, b int) int { + return 42 + }, + } + program, err := expr.Compile(`1 + 2`, expr.Operator("+", "plus"), expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, 42, out) +} + func TestArrayComparison(t *testing.T) { tests := []struct { env any From 9748859e10fef2bd8a6c3e07bfb54b95ac81631b Mon Sep 17 00:00:00 2001 From: Sergey Date: Thu, 21 Mar 2024 02:06:04 +0700 Subject: [PATCH 08/16] builtin `int` unwraps underlying int value (#611) --- builtin/builtin_test.go | 14 ++++++++++++++ builtin/lib.go | 4 ++++ 2 files changed, 18 insertions(+) diff --git a/builtin/builtin_test.go b/builtin/builtin_test.go index aa324c9be..7f5045f41 100644 --- a/builtin/builtin_test.go +++ b/builtin/builtin_test.go @@ -612,3 +612,17 @@ func TestBuiltin_bitOpsFunc(t *testing.T) { }) } } + +type customInt int + +func Test_int_unwraps_underlying_value(t *testing.T) { + env := map[string]any{ + "customInt": customInt(42), + } + program, err := expr.Compile(`int(customInt) == 42`, expr.Env(env)) + require.NoError(t, err) + + out, err := expr.Run(program, env) + require.NoError(t, err) + assert.Equal(t, true, out) +} diff --git a/builtin/lib.go b/builtin/lib.go index 9ff9478aa..e3a6c0aef 100644 --- a/builtin/lib.go +++ b/builtin/lib.go @@ -209,6 +209,10 @@ func Int(x any) any { } return i default: + val := reflect.ValueOf(x) + if val.CanConvert(integerType) { + return val.Convert(integerType).Interface() + } panic(fmt.Sprintf("invalid operation: int(%T)", x)) } } From fcebdadc96f31afb134e2f31d056418e5a5a0f74 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Wed, 27 Mar 2024 09:26:56 +0100 Subject: [PATCH 09/16] Better map ast printing --- ast/print.go | 8 +++++++- ast/print_test.go | 5 +++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/ast/print.go b/ast/print.go index fa593ae28..063e9eb27 100644 --- a/ast/print.go +++ b/ast/print.go @@ -202,5 +202,11 @@ func (n *MapNode) String() string { } func (n *PairNode) String() string { - return fmt.Sprintf("%s: %s", n.Key.String(), n.Value.String()) + if str, ok := n.Key.(*StringNode); ok { + if utils.IsValidIdentifier(str.Value) { + return fmt.Sprintf("%s: %s", str.Value, n.Value.String()) + } + return fmt.Sprintf("%q: %s", str.String(), n.Value.String()) + } + return fmt.Sprintf("(%s): %s", n.Key.String(), n.Value.String()) } diff --git a/ast/print_test.go b/ast/print_test.go index 16d64357b..d9e55c2ea 100644 --- a/ast/print_test.go +++ b/ast/print_test.go @@ -55,8 +55,8 @@ func TestPrint(t *testing.T) { {`func(a)`, `func(a)`}, {`func(a, b)`, `func(a, b)`}, {`{}`, `{}`}, - {`{a: b}`, `{"a": b}`}, - {`{a: b, c: d}`, `{"a": b, "c": d}`}, + {`{a: b}`, `{a: b}`}, + {`{a: b, c: d}`, `{a: b, c: d}`}, {`[]`, `[]`}, {`[a]`, `[a]`}, {`[a, b]`, `[a, b]`}, @@ -71,6 +71,7 @@ func TestPrint(t *testing.T) { {`a[1:]`, `a[1:]`}, {`a[:]`, `a[:]`}, {`(nil ?? 1) > 0`, `(nil ?? 1) > 0`}, + {`{("a" + "b"): 42}`, `{("a" + "b"): 42}`}, } for _, tt := range tests { From 51156fa1ff9782b0318c0c38b1c4ff0e0a651f5b Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Thu, 28 Mar 2024 21:37:51 +0800 Subject: [PATCH 10/16] feat: extract code for compiling equal operator (#614) --- compiler/compiler.go | 48 ++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/compiler/compiler.go b/compiler/compiler.go index 808b53c9b..a38d977d5 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -395,34 +395,12 @@ func (c *compiler) UnaryNode(node *ast.UnaryNode) { } func (c *compiler) BinaryNode(node *ast.BinaryNode) { - l := kind(node.Left) - r := kind(node.Right) - - leftIsSimple := isSimpleType(node.Left) - rightIsSimple := isSimpleType(node.Right) - leftAndRightAreSimple := leftIsSimple && rightIsSimple - switch node.Operator { case "==": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - - if l == r && l == reflect.Int && leftAndRightAreSimple { - c.emit(OpEqualInt) - } else if l == r && l == reflect.String && leftAndRightAreSimple { - c.emit(OpEqualString) - } else { - c.emit(OpEqual) - } + c.equalBinaryNode(node) case "!=": - c.compile(node.Left) - c.derefInNeeded(node.Left) - c.compile(node.Right) - c.derefInNeeded(node.Right) - c.emit(OpEqual) + c.equalBinaryNode(node) c.emit(OpNot) case "or", "||": @@ -580,6 +558,28 @@ func (c *compiler) BinaryNode(node *ast.BinaryNode) { } } +func (c *compiler) equalBinaryNode(node *ast.BinaryNode) { + l := kind(node.Left) + r := kind(node.Right) + + leftIsSimple := isSimpleType(node.Left) + rightIsSimple := isSimpleType(node.Right) + leftAndRightAreSimple := leftIsSimple && rightIsSimple + + c.compile(node.Left) + c.derefInNeeded(node.Left) + c.compile(node.Right) + c.derefInNeeded(node.Right) + + if l == r && l == reflect.Int && leftAndRightAreSimple { + c.emit(OpEqualInt) + } else if l == r && l == reflect.String && leftAndRightAreSimple { + c.emit(OpEqualString) + } else { + c.emit(OpEqual) + } +} + func isSimpleType(node ast.Node) bool { if node == nil { return false From 32579115c7bdf1eaa416d1b9531dd8cea4a55a64 Mon Sep 17 00:00:00 2001 From: Richard Wooding Date: Mon, 8 Apr 2024 09:00:54 +0200 Subject: [PATCH 11/16] Update README.md (#619) Add SPAN Digital as an user of expr --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index bd34c7d24..1475fe2f5 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,7 @@ func main() { * [Visually.io](https://visually.io) employs Expr as a business rule engine for its personalization targeting algorithm. * [Akvorado](https://github.com/akvorado/akvorado) utilizes Expr to classify exporters and interfaces in network flows. * [keda.sh](https://keda.sh) uses Expr to allow customization of its Kubernetes-based event-driven autoscaling. +* [SPAN Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products [Add your company too](https://github.com/expr-lang/expr/edit/master/README.md) From e4f0e78896f5e5b99c567b83a546fabaeef9abf3 Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Tue, 9 Apr 2024 09:47:41 +0200 Subject: [PATCH 12/16] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1475fe2f5..1a2d7dc83 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,7 @@ func main() { * [Visually.io](https://visually.io) employs Expr as a business rule engine for its personalization targeting algorithm. * [Akvorado](https://github.com/akvorado/akvorado) utilizes Expr to classify exporters and interfaces in network flows. * [keda.sh](https://keda.sh) uses Expr to allow customization of its Kubernetes-based event-driven autoscaling. -* [SPAN Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products +* [Span Digital](https://spandigital.com/) uses Expr in it's Knowledge Management products. [Add your company too](https://github.com/expr-lang/expr/edit/master/README.md) From 4454efff153c868f9bbe3fdea362b39023298eb9 Mon Sep 17 00:00:00 2001 From: needsure <166317845+needsure@users.noreply.github.com> Date: Tue, 9 Apr 2024 19:13:55 +0800 Subject: [PATCH 13/16] chore: fix some typos in conments (#622) Signed-off-by: needsure --- patcher/value/value.go | 4 ++-- test/operator/operator_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/patcher/value/value.go b/patcher/value/value.go index 59351be6b..28f52be27 100644 --- a/patcher/value/value.go +++ b/patcher/value/value.go @@ -13,9 +13,9 @@ import ( // ValueGetter is a Patcher that allows custom types to be represented as standard go values for use with expr. // It also adds the `$patcher_value_getter` function to the program for efficiently calling matching interfaces. // -// The purpose of this Patcher is to make it seemless to use custom types in expressions without the need to +// The purpose of this Patcher is to make it seamless to use custom types in expressions without the need to // first convert them to standard go values. It may also facilitate using already existing structs or maps as -// environments when they contain compatabile types. +// environments when they contain compatible types. // // An example usage may be modeling a database record with columns that have varying data types and constraints. // In such an example you may have custom types that, beyond storing a simple value, such as an integer, may diff --git a/test/operator/operator_test.go b/test/operator/operator_test.go index a19c191dc..b49d91cc6 100644 --- a/test/operator/operator_test.go +++ b/test/operator/operator_test.go @@ -77,7 +77,7 @@ func TestOperator_Function(t *testing.T) { } for _, tt := range tests { - t.Run(fmt.Sprintf(`opertor function helper test %s`, tt.input), func(t *testing.T) { + t.Run(fmt.Sprintf(`operator function helper test %s`, tt.input), func(t *testing.T) { program, err := expr.Compile( tt.input, expr.Env(env), From 582fb32d9214f6ef7a4dfcf754cb466787e15a7c Mon Sep 17 00:00:00 2001 From: Anton Medvedev Date: Wed, 10 Apr 2024 10:58:17 +0200 Subject: [PATCH 14/16] Revert "Optimize boolean operations between all, any, one, none functions (#555)" (#625) This reverts commit 3c03e5965172519f7bc12100db6607d6a9fae031. --- optimizer/optimizer.go | 1 - optimizer/optimizer_test.go | 122 ----------------------------- optimizer/predicate_combination.go | 51 ------------ 3 files changed, 174 deletions(-) delete mode 100644 optimizer/predicate_combination.go diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 6d1fb0b54..a9c0fa3d3 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -36,6 +36,5 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &filterLen{}) Walk(node, &filterLast{}) Walk(node, &filterFirst{}) - Walk(node, &predicateCombination{}) return nil } diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index 703bd1ceb..e45de763b 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -1,7 +1,6 @@ package optimizer_test import ( - "fmt" "reflect" "strings" "testing" @@ -340,124 +339,3 @@ func TestOptimize_filter_map_first(t *testing.T) { assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) } - -func TestOptimize_predicate_combination(t *testing.T) { - tests := []struct { - op string - fn string - wantOp string - }{ - {"and", "all", "and"}, - {"&&", "all", "&&"}, - {"or", "all", "or"}, - {"||", "all", "||"}, - {"and", "any", "and"}, - {"&&", "any", "&&"}, - {"or", "any", "or"}, - {"||", "any", "||"}, - {"and", "none", "or"}, - {"&&", "none", "||"}, - {"and", "one", "or"}, - {"&&", "one", "||"}, - } - - for _, tt := range tests { - rule := fmt.Sprintf(`%s(users, .Age > 18 and .Name != "Bob") %s %s(users, .Age < 30)`, tt.fn, tt.op, tt.fn) - t.Run(rule, func(t *testing.T) { - tree, err := parser.Parse(rule) - require.NoError(t, err) - - err = optimizer.Optimize(&tree.Node, nil) - require.NoError(t, err) - - expected := &ast.BuiltinNode{ - Name: tt.fn, - Arguments: []ast.Node{ - &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: tt.wantOp, - Left: &ast.BinaryNode{ - Operator: "and", - Left: &ast.BinaryNode{ - Operator: ">", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Age"}, - }, - Right: &ast.IntegerNode{Value: 18}, - }, - Right: &ast.BinaryNode{ - Operator: "!=", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Name"}, - }, - Right: &ast.StringNode{Value: "Bob"}, - }, - }, - Right: &ast.BinaryNode{ - Operator: "<", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Age"}, - }, - Right: &ast.IntegerNode{Value: 30}, - }, - }, - }, - }, - } - assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) - }) - } -} - -func TestOptimize_predicate_combination_nested(t *testing.T) { - tree, err := parser.Parse(`any(users, {all(.Friends, {.Age == 18 })}) && any(users, {all(.Friends, {.Name != "Bob" })})`) - require.NoError(t, err) - - err = optimizer.Optimize(&tree.Node, nil) - require.NoError(t, err) - - expected := &ast.BuiltinNode{ - Name: "any", - Arguments: []ast.Node{ - &ast.IdentifierNode{Value: "users"}, - &ast.ClosureNode{ - Node: &ast.BuiltinNode{ - Name: "all", - Arguments: []ast.Node{ - &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Friends"}, - }, - &ast.ClosureNode{ - Node: &ast.BinaryNode{ - Operator: "&&", - Left: &ast.BinaryNode{ - Operator: "==", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Age"}, - }, - Right: &ast.IntegerNode{Value: 18}, - }, - Right: &ast.BinaryNode{ - Operator: "!=", - Left: &ast.MemberNode{ - Node: &ast.PointerNode{}, - Property: &ast.StringNode{Value: "Name"}, - }, - Right: &ast.StringNode{Value: "Bob"}, - }, - }, - }, - }, - }, - }, - }, - } - - assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) -} diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go deleted file mode 100644 index 2733781df..000000000 --- a/optimizer/predicate_combination.go +++ /dev/null @@ -1,51 +0,0 @@ -package optimizer - -import ( - . "github.com/expr-lang/expr/ast" - "github.com/expr-lang/expr/parser/operator" -) - -type predicateCombination struct{} - -func (v *predicateCombination) Visit(node *Node) { - if op, ok := (*node).(*BinaryNode); ok && operator.IsBoolean(op.Operator) { - if left, ok := op.Left.(*BuiltinNode); ok { - if combinedOp, ok := combinedOperator(left.Name, op.Operator); ok { - if right, ok := op.Right.(*BuiltinNode); ok && right.Name == left.Name { - if left.Arguments[0].Type() == right.Arguments[0].Type() && left.Arguments[0].String() == right.Arguments[0].String() { - closure := &ClosureNode{ - Node: &BinaryNode{ - Operator: combinedOp, - Left: left.Arguments[1].(*ClosureNode).Node, - Right: right.Arguments[1].(*ClosureNode).Node, - }, - } - v.Visit(&closure.Node) - Patch(node, &BuiltinNode{ - Name: left.Name, - Arguments: []Node{ - left.Arguments[0], - closure, - }, - }) - } - } - } - } - } -} - -func combinedOperator(fn, op string) (string, bool) { - switch fn { - case "all", "any": - return op, true - case "one", "none": - switch op { - case "and": - return "or", true - case "&&": - return "||", true - } - } - return "", false -} From 55dc4e80bc10639d13f4739ea8a767edb8b357f4 Mon Sep 17 00:00:00 2001 From: Sergey Date: Fri, 12 Apr 2024 23:49:41 +0700 Subject: [PATCH 15/16] Optimize boolean operations between all, any, none functions (#626) --- expr_test.go | 189 +++++++++++++++++++++++++++++ optimizer/optimizer.go | 1 + optimizer/optimizer_test.go | 116 ++++++++++++++++++ optimizer/predicate_combination.go | 61 ++++++++++ 4 files changed, 367 insertions(+) create mode 100644 optimizer/predicate_combination.go diff --git a/expr_test.go b/expr_test.go index 790fdd5d9..ac8eecf48 100644 --- a/expr_test.go +++ b/expr_test.go @@ -901,18 +901,147 @@ func TestExpr(t *testing.T) { `all(1..3, {# > 0})`, true, }, + { + `all(1..3, {# > 0}) && all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# > 2}) && all(1..3, {# < 4})`, + false, + }, + { + `all(1..3, {# > 0}) && all(1..3, {# < 2})`, + false, + }, + { + `all(1..3, {# > 2}) && all(1..3, {# < 2})`, + false, + }, + { + `all(1..3, {# > 0}) || all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# > 0}) || all(1..3, {# != 2})`, + true, + }, + { + `all(1..3, {# != 3}) || all(1..3, {# < 4})`, + true, + }, + { + `all(1..3, {# != 3}) || all(1..3, {# != 2})`, + false, + }, { `none(1..3, {# == 0})`, true, }, + { + `none(1..3, {# == 0}) && none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 0}) && none(1..3, {# == 3})`, + false, + }, + { + `none(1..3, {# == 1}) && none(1..3, {# == 4})`, + false, + }, + { + `none(1..3, {# == 1}) && none(1..3, {# == 3})`, + false, + }, + { + `none(1..3, {# == 0}) || none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 0}) || none(1..3, {# == 3})`, + true, + }, + { + `none(1..3, {# == 1}) || none(1..3, {# == 4})`, + true, + }, + { + `none(1..3, {# == 1}) || none(1..3, {# == 3})`, + false, + }, { `any([1,1,0,1], {# == 0})`, true, }, + { + `any(1..3, {# == 1}) && any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 0}) && any(1..3, {# == 2})`, + false, + }, + { + `any(1..3, {# == 1}) && any(1..3, {# == 4})`, + false, + }, + { + `any(1..3, {# == 0}) && any(1..3, {# == 4})`, + false, + }, + { + `any(1..3, {# == 1}) || any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 0}) || any(1..3, {# == 2})`, + true, + }, + { + `any(1..3, {# == 1}) || any(1..3, {# == 4})`, + true, + }, + { + `any(1..3, {# == 0}) || any(1..3, {# == 4})`, + false, + }, { `one([1,1,0,1], {# == 0}) and not one([1,0,0,1], {# == 0})`, true, }, + { + `one(1..3, {# == 1}) and one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, + false, + }, + { + `one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, + false, + }, + { + `one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, + false, + }, + { + `one(1..3, {# == 1}) or one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, + true, + }, + { + `one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, + true, + }, + { + `one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, + false, + }, + { `count(1..30, {# % 3 == 0})`, 10, @@ -2525,6 +2654,66 @@ func TestOperatorDependsOnEnv(t *testing.T) { assert.Equal(t, 42, out) } +func TestIssue624(t *testing.T) { + type tag struct { + Name string + } + + type item struct { + Tags []tag + } + + i := item{ + Tags: []tag{ + {Name: "one"}, + {Name: "two"}, + }, + } + + rule := `[ +true && true, +one(Tags, .Name in ["one"]), +one(Tags, .Name in ["two"]), +one(Tags, .Name in ["one"]) && one(Tags, .Name in ["two"]) +]` + resp, err := expr.Eval(rule, i) + require.NoError(t, err) + require.Equal(t, []interface{}{true, true, true, true}, resp) +} + +func TestPredicateCombination(t *testing.T) { + tests := []struct { + code1 string + code2 string + }{ + {"all(1..3, {# > 0}) && all(1..3, {# < 4})", "all(1..3, {# > 0 && # < 4})"}, + {"all(1..3, {# > 1}) && all(1..3, {# < 4})", "all(1..3, {# > 1 && # < 4})"}, + {"all(1..3, {# > 0}) && all(1..3, {# < 2})", "all(1..3, {# > 0 && # < 2})"}, + {"all(1..3, {# > 1}) && all(1..3, {# < 2})", "all(1..3, {# > 1 && # < 2})"}, + + {"any(1..3, {# > 0}) || any(1..3, {# < 4})", "any(1..3, {# > 0 || # < 4})"}, + {"any(1..3, {# > 1}) || any(1..3, {# < 4})", "any(1..3, {# > 1 || # < 4})"}, + {"any(1..3, {# > 0}) || any(1..3, {# < 2})", "any(1..3, {# > 0 || # < 2})"}, + {"any(1..3, {# > 1}) || any(1..3, {# < 2})", "any(1..3, {# > 1 || # < 2})"}, + + {"none(1..3, {# > 0}) && none(1..3, {# < 4})", "none(1..3, {# > 0 || # < 4})"}, + {"none(1..3, {# > 1}) && none(1..3, {# < 4})", "none(1..3, {# > 1 || # < 4})"}, + {"none(1..3, {# > 0}) && none(1..3, {# < 2})", "none(1..3, {# > 0 || # < 2})"}, + {"none(1..3, {# > 1}) && none(1..3, {# < 2})", "none(1..3, {# > 1 || # < 2})"}, + } + for _, tt := range tests { + t.Run(tt.code1, func(t *testing.T) { + out1, err := expr.Eval(tt.code1, nil) + require.NoError(t, err) + + out2, err := expr.Eval(tt.code2, nil) + require.NoError(t, err) + + require.Equal(t, out1, out2) + }) + } +} + func TestArrayComparison(t *testing.T) { tests := []struct { env any diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index a9c0fa3d3..6d1fb0b54 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -36,5 +36,6 @@ func Optimize(node *Node, config *conf.Config) error { Walk(node, &filterLen{}) Walk(node, &filterLast{}) Walk(node, &filterFirst{}) + Walk(node, &predicateCombination{}) return nil } diff --git a/optimizer/optimizer_test.go b/optimizer/optimizer_test.go index e45de763b..316b17182 100644 --- a/optimizer/optimizer_test.go +++ b/optimizer/optimizer_test.go @@ -1,6 +1,7 @@ package optimizer_test import ( + "fmt" "reflect" "strings" "testing" @@ -339,3 +340,118 @@ func TestOptimize_filter_map_first(t *testing.T) { assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) } + +func TestOptimize_predicate_combination(t *testing.T) { + tests := []struct { + op string + fn string + wantOp string + }{ + {"and", "all", "and"}, + {"&&", "all", "&&"}, + {"or", "any", "or"}, + {"||", "any", "||"}, + {"and", "none", "or"}, + {"&&", "none", "||"}, + } + + for _, tt := range tests { + rule := fmt.Sprintf(`%s(users, .Age > 18 and .Name != "Bob") %s %s(users, .Age < 30)`, tt.fn, tt.op, tt.fn) + t.Run(rule, func(t *testing.T) { + tree, err := parser.Parse(rule) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: tt.fn, + Arguments: []ast.Node{ + &ast.IdentifierNode{Value: "users"}, + &ast.ClosureNode{ + Node: &ast.BinaryNode{ + Operator: tt.wantOp, + Left: &ast.BinaryNode{ + Operator: "and", + Left: &ast.BinaryNode{ + Operator: ">", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 18}, + }, + Right: &ast.BinaryNode{ + Operator: "!=", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Name"}, + }, + Right: &ast.StringNode{Value: "Bob"}, + }, + }, + Right: &ast.BinaryNode{ + Operator: "<", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 30}, + }, + }, + }, + }, + } + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) + }) + } +} + +func TestOptimize_predicate_combination_nested(t *testing.T) { + tree, err := parser.Parse(`all(users, {all(.Friends, {.Age == 18 })}) && all(users, {all(.Friends, {.Name != "Bob" })})`) + require.NoError(t, err) + + err = optimizer.Optimize(&tree.Node, nil) + require.NoError(t, err) + + expected := &ast.BuiltinNode{ + Name: "all", + Arguments: []ast.Node{ + &ast.IdentifierNode{Value: "users"}, + &ast.ClosureNode{ + Node: &ast.BuiltinNode{ + Name: "all", + Arguments: []ast.Node{ + &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Friends"}, + }, + &ast.ClosureNode{ + Node: &ast.BinaryNode{ + Operator: "&&", + Left: &ast.BinaryNode{ + Operator: "==", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Age"}, + }, + Right: &ast.IntegerNode{Value: 18}, + }, + Right: &ast.BinaryNode{ + Operator: "!=", + Left: &ast.MemberNode{ + Node: &ast.PointerNode{}, + Property: &ast.StringNode{Value: "Name"}, + }, + Right: &ast.StringNode{Value: "Bob"}, + }, + }, + }, + }, + }, + }, + }, + } + + assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node)) +} diff --git a/optimizer/predicate_combination.go b/optimizer/predicate_combination.go new file mode 100644 index 000000000..6e8a7f7cf --- /dev/null +++ b/optimizer/predicate_combination.go @@ -0,0 +1,61 @@ +package optimizer + +import ( + . "github.com/expr-lang/expr/ast" + "github.com/expr-lang/expr/parser/operator" +) + +/* +predicateCombination is a visitor that combines multiple predicate calls into a single call. +For example, the following expression: + + all(x, x > 1) && all(x, x < 10) -> all(x, x > 1 && x < 10) + any(x, x > 1) || any(x, x < 10) -> any(x, x > 1 || x < 10) + none(x, x > 1) && none(x, x < 10) -> none(x, x > 1 || x < 10) +*/ +type predicateCombination struct{} + +func (v *predicateCombination) Visit(node *Node) { + if op, ok := (*node).(*BinaryNode); ok && operator.IsBoolean(op.Operator) { + if left, ok := op.Left.(*BuiltinNode); ok { + if combinedOp, ok := combinedOperator(left.Name, op.Operator); ok { + if right, ok := op.Right.(*BuiltinNode); ok && right.Name == left.Name { + if left.Arguments[0].Type() == right.Arguments[0].Type() && left.Arguments[0].String() == right.Arguments[0].String() { + closure := &ClosureNode{ + Node: &BinaryNode{ + Operator: combinedOp, + Left: left.Arguments[1].(*ClosureNode).Node, + Right: right.Arguments[1].(*ClosureNode).Node, + }, + } + v.Visit(&closure.Node) + Patch(node, &BuiltinNode{ + Name: left.Name, + Arguments: []Node{ + left.Arguments[0], + closure, + }, + }) + } + } + } + } + } +} + +func combinedOperator(fn, op string) (string, bool) { + switch { + case fn == "all" && (op == "and" || op == "&&"): + return op, true + case fn == "any" && (op == "or" || op == "||"): + return op, true + case fn == "none" && (op == "and" || op == "&&"): + switch op { + case "and": + return "or", true + case "&&": + return "||", true + } + } + return "", false +} From 524efca4b271a8ac6d62f6271c885f4e3e91002a Mon Sep 17 00:00:00 2001 From: Daenney Date: Fri, 12 Apr 2024 22:54:01 +0200 Subject: [PATCH 16/16] Make WithContext work for methods on env struct (#602) This makes the WithContext patcher work when passing in a struct for an env that has methods that take a context. This is a bit fiddly because we can't quite detect the difference of function vs. method on a struct, so we have to check both the first and the second param. Since it's highly unusual to pass the context as anything other than the first parameter, this should work out just fine in practice. Fixes #600. --- patcher/with_context.go | 15 +++++++++++---- patcher/with_context_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/patcher/with_context.go b/patcher/with_context.go index 55b604261..f9861a2c2 100644 --- a/patcher/with_context.go +++ b/patcher/with_context.go @@ -22,11 +22,18 @@ func (w WithContext) Visit(node *ast.Node) { if fn.Kind() != reflect.Func { return } - if fn.NumIn() == 0 { - return - } - if fn.In(0).String() != "context.Context" { + switch fn.NumIn() { + case 0: return + case 1: + if fn.In(0).String() != "context.Context" { + return + } + default: + if fn.In(0).String() != "context.Context" && + fn.In(1).String() != "context.Context" { + return + } } ast.Patch(node, &ast.CallNode{ Callee: call.Callee, diff --git a/patcher/with_context_test.go b/patcher/with_context_test.go index afad4e6f0..5ce64191f 100644 --- a/patcher/with_context_test.go +++ b/patcher/with_context_test.go @@ -62,6 +62,30 @@ func TestWithContext_with_env_Function(t *testing.T) { require.Equal(t, 42, output) } +type testEnvContext struct { + Context context.Context `expr:"ctx"` +} + +func (testEnvContext) Fn(ctx context.Context, a int) int { + return ctx.Value("value").(int) + a +} + +func TestWithContext_env_struct(t *testing.T) { + withContext := patcher.WithContext{Name: "ctx"} + + program, err := expr.Compile(`Fn(40)`, expr.Env(testEnvContext{}), expr.Patch(withContext)) + require.NoError(t, err) + + ctx := context.WithValue(context.Background(), "value", 2) + env := testEnvContext{ + Context: ctx, + } + + output, err := expr.Run(program, env) + require.NoError(t, err) + require.Equal(t, 42, output) +} + type TestFoo struct { contextValue int }