diff --git a/checker/checker.go b/checker/checker.go index b46178d4..49cf5dee 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -481,6 +481,12 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) { } return base.Elem(), info{} + case reflect.String: + if !isInteger(prop) && !isAny(prop) { + return v.error(node.Property, "string elements can only be selected using an integer (got %v)", prop) + } + return stringType, info{} + case reflect.Struct: if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value diff --git a/compiler/compiler.go b/compiler/compiler.go index a4f189e6..453d0b67 100644 --- a/compiler/compiler.go +++ b/compiler/compiler.go @@ -629,7 +629,11 @@ func (c *compiler) MemberNode(node *ast.MemberNode) { if op == OpFetch { c.compile(node.Property) - c.emit(OpFetch) + if node.Optional { + c.emit(OpOptionalFetch) + } else { + c.emit(op) + } } else { c.emitLocation(node.Location(), op, c.addConstant( &runtime.Field{Index: index, Path: path}, diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index 741142a7..7864d636 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -515,6 +515,71 @@ func TestCompile_optimizes_jumps(t *testing.T) { {vm.OpNil, 0}, }, }, + { + `let m = {"a": {"b": {"c": 1}}}; m.a.b.c ?? 'nil coalescing'`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpPush, 2}, + {vm.OpPush, 3}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpStore, 0}, + {vm.OpLoadVar, 0}, + {vm.OpJumpIfNil, 8}, + {vm.OpPush, 0}, + {vm.OpOptionalFetch, 0}, + {vm.OpJumpIfNil, 5}, + {vm.OpPush, 1}, + {vm.OpOptionalFetch, 0}, + {vm.OpJumpIfNil, 2}, + {vm.OpPush, 2}, + {vm.OpOptionalFetch, 0}, + {vm.OpDeref, 0}, + {vm.OpJumpIfNotNil, 2}, + {vm.OpPop, 0}, + {vm.OpPush, 4}, + }, + }, + { + `let m = [{"a": {"b": {"c": 1}}}]; m[5].a.b.c ?? 'nil coalescing'`, + []op{ + {vm.OpPush, 0}, + {vm.OpPush, 1}, + {vm.OpPush, 2}, + {vm.OpPush, 3}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpPush, 3}, + {vm.OpMap, 0}, + {vm.OpPush, 3}, + {vm.OpArray, 0}, + {vm.OpStore, 0}, + {vm.OpLoadVar, 0}, + {vm.OpJumpIfNil, 11}, + {vm.OpPush, 4}, + {vm.OpOptionalFetch, 0}, + {vm.OpJumpIfNil, 8}, + {vm.OpPush, 0}, + {vm.OpOptionalFetch, 0}, + {vm.OpJumpIfNil, 5}, + {vm.OpPush, 1}, + {vm.OpOptionalFetch, 0}, + {vm.OpJumpIfNil, 2}, + {vm.OpPush, 2}, + {vm.OpOptionalFetch, 0}, + {vm.OpDeref, 0}, + {vm.OpJumpIfNotNil, 2}, + {vm.OpPop, 0}, + {vm.OpPush, 5}, + }, + }, { `let m = {"a": {"b": {"c": 1}}}; m?.a?.b?.c`, []op{ @@ -532,13 +597,24 @@ func TestCompile_optimizes_jumps(t *testing.T) { {vm.OpLoadVar, 0}, {vm.OpJumpIfNil, 8}, {vm.OpPush, 0}, - {vm.OpFetch, 0}, + {vm.OpOptionalFetch, 0}, {vm.OpJumpIfNil, 5}, {vm.OpPush, 1}, - {vm.OpFetch, 0}, + {vm.OpOptionalFetch, 0}, {vm.OpJumpIfNil, 2}, {vm.OpPush, 2}, - {vm.OpFetch, 0}, + {vm.OpOptionalFetch, 0}, + }, + }, + { + `let m = [1, 2, 3]; m?.[5]`, + []op{ + {vm.OpPush, 0}, + {vm.OpStore, 0}, + {vm.OpLoadVar, 0}, + {vm.OpJumpIfNil, 2}, + {vm.OpPush, 1}, + {vm.OpOptionalFetch, 0}, }, }, } diff --git a/expr_test.go b/expr_test.go index 9d9d88af..1102bc02 100644 --- a/expr_test.go +++ b/expr_test.go @@ -1273,6 +1273,18 @@ func TestExpr(t *testing.T) { `1 < 2 < 3 == true`, true, }, + { + `[1, 2, 3]?.[5]`, + nil, + }, + { + `'string'?.[5]`, + "g", + }, + { + `'string'?.[7]`, + nil, + }, } for _, tt := range tests { @@ -1967,6 +1979,24 @@ func TestRun_NilCoalescingOperator(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "default", out) }) + + t.Run("default without chain", func(t *testing.T) { + p, err := expr.Compile(`foo.foo.bar ?? "default"`, expr.Env(env)) + assert.NoError(t, err) + + out, err := expr.Run(p, map[string]any{}) + assert.NoError(t, err) + assert.Equal(t, "default", out) + }) + + t.Run("array default without chain", func(t *testing.T) { + p, err := expr.Compile(`foo.foo[10].bar ?? "default"`, expr.Env(env)) + assert.NoError(t, err) + + out, err := expr.Run(p, map[string]any{}) + assert.NoError(t, err) + assert.Equal(t, "default", out) + }) } func TestEval_nil_in_maps(t *testing.T) { diff --git a/parser/parser.go b/parser/parser.go index 9114bc0c..f6c955db 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -176,6 +176,10 @@ func (p *parser) parseExpression(precedence int) Node { nodeRight = p.parseExpression(op.Precedence) } + if opToken.Value == "??" { + nodeLeft = p.flattenChainedOptionals(nodeLeft) + } + nodeLeft = &BinaryNode{ Operator: opToken.Value, Left: nodeLeft, @@ -579,65 +583,66 @@ end: } func (p *parser) parsePostfixExpression(node Node) Node { - postfixToken := p.current - for (postfixToken.Is(Operator) || postfixToken.Is(Bracket)) && p.err == nil { - optional := postfixToken.Value == "?." - parseToken: - if postfixToken.Value == "." || postfixToken.Value == "?." { - p.next() - - propertyToken := p.current - if optional && propertyToken.Is(Bracket, "[") { - postfixToken = propertyToken - goto parseToken - } + for postfixToken := p.current; (postfixToken.Is(Operator) || postfixToken.Is(Bracket)) && p.err == nil; postfixToken = p.current { + switch postfixToken.Value { + case ".", "?.": p.next() - if propertyToken.Kind != Identifier && - // Operators like "not" and "matches" are valid methods or property names. - (propertyToken.Kind != Operator || !utils.IsValidIdentifier(propertyToken.Value)) { - p.error("expected name") - } - - property := &StringNode{Value: propertyToken.Value} - property.SetLocation(propertyToken.Location) - chainNode, isChain := node.(*ChainNode) - optional := postfixToken.Value == "?." - if isChain { node = chainNode.Node } + propertyToken := p.current + isOptional := postfixToken.Value == "?." - memberNode := &MemberNode{ - Node: node, - Property: property, - Optional: optional, - } - memberNode.SetLocation(propertyToken.Location) - - if p.current.Is(Bracket, "(") { - memberNode.Method = true - node = &CallNode{ - Callee: memberNode, - Arguments: p.parseArguments([]Node{}), + if isOptional && p.current.Is(Bracket, "[") { + p.next() + node = &MemberNode{ + Node: node, + Property: p.parseExpression(0), + Optional: true, } node.SetLocation(propertyToken.Location) + p.expect(Bracket, "]") } else { - node = memberNode - } + p.next() + + if propertyToken.Kind != Identifier && + // Operators like "not" and "matches" are valid methods or property names. + (propertyToken.Kind != Operator || !utils.IsValidIdentifier(propertyToken.Value)) { + p.error("expected name") + } - if isChain || optional { + property := &StringNode{Value: propertyToken.Value} + property.SetLocation(propertyToken.Location) + + memberNode := &MemberNode{ + Node: node, + Property: property, + Optional: isOptional, + } + memberNode.SetLocation(propertyToken.Location) + + if p.current.Is(Bracket, "(") { + memberNode.Method = true + node = &CallNode{ + Callee: memberNode, + Arguments: p.parseArguments([]Node{}), + } + node.SetLocation(propertyToken.Location) + } else { + node = memberNode + } + } + if isChain || isOptional { node = &ChainNode{Node: node} } - - } else if postfixToken.Value == "[" { + case "[": p.next() - var from, to Node - if p.current.Is(Operator, ":") { // slice without from [:1] p.next() + var to Node if !p.current.Is(Bracket, "]") { // slice without from and to [:] to = p.parseExpression(0) } @@ -646,16 +651,13 @@ func (p *parser) parsePostfixExpression(node Node) Node { Node: node, To: to, } - node.SetLocation(postfixToken.Location) - p.expect(Bracket, "]") - } else { - - from = p.parseExpression(0) + from := p.parseExpression(0) if p.current.Is(Operator, ":") { p.next() + var to Node if !p.current.Is(Bracket, "]") { // slice without to [1:] to = p.parseExpression(0) } @@ -665,28 +667,20 @@ func (p *parser) parsePostfixExpression(node Node) Node { From: from, To: to, } - node.SetLocation(postfixToken.Location) - p.expect(Bracket, "]") - } else { // Slice operator [:] was not found, // it should be just an index node. node = &MemberNode{ Node: node, Property: from, - Optional: optional, } - node.SetLocation(postfixToken.Location) - if optional { - node = &ChainNode{Node: node} - } - p.expect(Bracket, "]") } } - } else { - break + node.SetLocation(postfixToken.Location) + p.expect(Bracket, "]") + default: + return node } - postfixToken = p.current } return node } @@ -721,3 +715,17 @@ func (p *parser) parseComparison(left Node, token Token, precedence int) Node { } return rootNode } + +func (p *parser) flattenChainedOptionals(nodeLeft Node) Node { + switch node := nodeLeft.(type) { + case *MemberNode: + node.Optional = true + node.Node = p.flattenChainedOptionals(node.Node) + return &ChainNode{Node: node} + case *ChainNode: + node.Node = p.flattenChainedOptionals(node.Node) + return node + default: + return nodeLeft + } +} diff --git a/parser/parser_test.go b/parser/parser_test.go index 9225e102..1020c9a2 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -449,6 +449,25 @@ world`}, Right: &IdentifierNode{Value: "bar"}}, Right: &IdentifierNode{Value: "baz"}}, }, + { + "foo.bar.baz ?? 'nil'", + &BinaryNode{Operator: "??", + Left: &ChainNode{ + Node: &MemberNode{ + Node: &ChainNode{ + Node: &MemberNode{ + Node: &IdentifierNode{Value: "foo"}, + Property: &StringNode{Value: "bar"}, + Optional: true, + }, + }, + Property: &StringNode{Value: "baz"}, + Optional: true, + }, + }, + Right: &StringNode{Value: "nil"}, + }, + }, { "foo ?? (bar || baz)", &BinaryNode{Operator: "??", diff --git a/vm/opcodes.go b/vm/opcodes.go index 0417dab6..f03ced15 100644 --- a/vm/opcodes.go +++ b/vm/opcodes.go @@ -17,6 +17,7 @@ const ( OpLoadEnv OpFetch OpFetchField + OpOptionalFetch OpMethod OpTrue OpFalse diff --git a/vm/program.go b/vm/program.go index 4a878267..ed2c9d14 100644 --- a/vm/program.go +++ b/vm/program.go @@ -164,6 +164,9 @@ func (program *Program) DisassembleWriter(w io.Writer) { case OpFetch: code("OpFetch") + case OpOptionalFetch: + code("OpOptionalFetch") + case OpFetchField: constant("OpFetchField") diff --git a/vm/runtime/runtime.go b/vm/runtime/runtime.go index 7da1320d..4371f0f1 100644 --- a/vm/runtime/runtime.go +++ b/vm/runtime/runtime.go @@ -11,66 +11,80 @@ import ( ) func Fetch(from, i any) any { + return fetch(from, i, true) +} + +func OptionalFetch(from, i any) any { + return fetch(from, i, false) +} + +func fetch(from, i any, strict bool) any { v := reflect.ValueOf(from) - if v.Kind() == reflect.Invalid { - panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) - } + if v.Kind() != reflect.Invalid { - // Methods can be defined on any type. - if v.NumMethod() > 0 { - if methodName, ok := i.(string); ok { - method := v.MethodByName(methodName) - if method.IsValid() { - return method.Interface() + // Methods can be defined on any type. + if v.NumMethod() > 0 { + if methodName, ok := i.(string); ok { + method := v.MethodByName(methodName) + if method.IsValid() { + return method.Interface() + } } } - } - - // Structs, maps, and slices can be access through a pointer or through - // a value, when they are accessed through a pointer we don't want to - // copy them to a value. - // De-reference everything if necessary (interface and pointers) - v = deref.Value(v) - switch v.Kind() { - case reflect.Array, reflect.Slice, reflect.String: - index := ToInt(i) - if index < 0 { - index = v.Len() + index - } - value := v.Index(index) - if value.IsValid() { - return value.Interface() - } + // Structs, maps, and slices can be access through a pointer or through + // a value, when they are accessed through a pointer we don't want to + // copy them to a value. + // De-reference everything if necessary (interface and pointers) + v = deref.Value(v) + kind := v.Kind() + + switch kind { + case reflect.Array, reflect.Slice, reflect.String: + index := ToInt(i) + size := v.Len() + if index < 0 { + index += size + } + if (index >= 0 && index < size) || strict { + if kind == reflect.String { + return string(v.Index(index).Interface().(uint8)) + } + return v.Index(index).Interface() + } - case reflect.Map: - var value reflect.Value - if i == nil { - value = v.MapIndex(reflect.Zero(v.Type().Key())) - } else { - value = v.MapIndex(reflect.ValueOf(i)) - } - if value.IsValid() { - return value.Interface() - } else { - elem := reflect.TypeOf(from).Elem() - return reflect.Zero(elem).Interface() - } + case reflect.Map: + var value reflect.Value + if i == nil { + value = v.MapIndex(reflect.Zero(v.Type().Key())) + } else { + value = v.MapIndex(reflect.ValueOf(i)) + } + if value.IsValid() { + return value.Interface() + } else { + elem := reflect.TypeOf(from).Elem() + return reflect.Zero(elem).Interface() + } - case reflect.Struct: - fieldName := i.(string) - value := v.FieldByNameFunc(func(name string) bool { - field, _ := v.Type().FieldByName(name) - if field.Tag.Get("expr") == fieldName { - return true + case reflect.Struct: + fieldName := i.(string) + value := v.FieldByNameFunc(func(name string) bool { + field, _ := v.Type().FieldByName(name) + if field.Tag.Get("expr") == fieldName { + return true + } + return name == fieldName + }) + if value.IsValid() { + return value.Interface() } - return name == fieldName - }) - if value.IsValid() { - return value.Interface() } } - panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) + if strict { + panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) + } + return nil } type Field struct { diff --git a/vm/vm.go b/vm/vm.go index 1e85893b..f8bbf750 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -128,6 +128,11 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) { a := vm.pop() vm.push(runtime.Fetch(a, b)) + case OpOptionalFetch: + b := vm.pop() + a := vm.pop() + vm.push(runtime.OptionalFetch(a, b)) + case OpFetchField: a := vm.pop() vm.push(runtime.FetchField(a, program.Constants[arg].(*runtime.Field)))