diff --git a/doc/spec.md b/doc/spec.md index ba5317a0..8711f0a2 100644 --- a/doc/spec.md +++ b/doc/spec.md @@ -3762,11 +3762,11 @@ x.symmetric_difference([3, 4, 5]) # set([1, 2, 4, 5]) ### set·union -`S.union(iterable)` returns a new set into which have been inserted -all the elements of set S and all the elements of the argument, which -must be iterable. +`S.union(iterable...)` returns a new set into which have been inserted +all the elements of set S and each element of the iterable sequences. -`union` fails if any element of the iterable is not hashable. +`union` fails if any argument is not an iterable sequence, or if any +sequence element is not hashable. ```python x = set([1, 2]) diff --git a/starlark/library.go b/starlark/library.go index 10dfdfd1..6d7ff454 100644 --- a/starlark/library.go +++ b/starlark/library.go @@ -2337,41 +2337,18 @@ func set_symmetric_difference(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) // https://github.com/google/starlark-go/blob/master/doc/spec.md#set·union. func set_union(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { - var iterable Iterable - if err := UnpackPositionalArgs(b.Name(), args, kwargs, 0, &iterable); err != nil { - return nil, err - } - iter := iterable.Iterate() - defer iter.Done() - union, err := b.Receiver().(*Set).Union(iter) - if err != nil { + receiverSet := b.Receiver().(*Set).clone() + if err := setUpdate(receiverSet, args, kwargs); err != nil { return nil, nameErr(b, err) } - return union, nil + return receiverSet, nil } // https://github.com/google/starlark-go/blob/master/doc/spec.md#set·update. func set_update(_ *Thread, b *Builtin, args Tuple, kwargs []Tuple) (Value, error) { - if len(kwargs) > 0 { - return nil, nameErr(b, "update does not accept keyword arguments") - } - - receiverSet := b.Receiver().(*Set) - - for i, arg := range args { - iterable, ok := arg.(Iterable) - if !ok { - return nil, fmt.Errorf("update: argument #%d is not iterable: %s", i+1, arg.Type()) - } - if err := func() error { - iter := iterable.Iterate() - defer iter.Done() - return receiverSet.InsertAll(iter) - }(); err != nil { - return nil, nameErr(b, err) - } + if err := setUpdate(b.Receiver().(*Set), args, kwargs); err != nil { + return nil, nameErr(b, err) } - return None, nil } @@ -2474,6 +2451,28 @@ func updateDict(dict *Dict, updates Tuple, kwargs []Tuple) error { return nil } +func setUpdate(s *Set, args Tuple, kwargs []Tuple) error { + if len(kwargs) > 0 { + return errors.New("does not accept keyword arguments") + } + + for i, arg := range args { + iterable, ok := arg.(Iterable) + if !ok { + return fmt.Errorf("argument #%d is not iterable: %s", i+1, arg.Type()) + } + if err := func() error { + iter := iterable.Iterate() + defer iter.Done() + return s.InsertAll(iter) + }(); err != nil { + return err + } + } + + return nil +} + // nameErr returns an error message of the form "name: msg" // where name is b.Name() and msg is a string or error. func nameErr(b *Builtin, msg interface{}) error { diff --git a/starlark/testdata/set.star b/starlark/testdata/set.star index 7a831129..9aa722b1 100644 --- a/starlark/testdata/set.star +++ b/starlark/testdata/set.star @@ -59,14 +59,23 @@ assert.eq(list(set("a".elems()).union("b".elems())), ["a", "b"]) assert.eq(list(set("ab".elems()).union("bc".elems())), ["a", "b", "c"]) assert.eq(set().union([]), set()) assert.eq(type(x.union(y)), "set") +assert.eq(list(x.union()), [1, 2, 3]) assert.eq(list(x.union(y)), [1, 2, 3, 4, 5]) +assert.eq(list(x.union(y, [6, 7])), [1, 2, 3, 4, 5, 6, 7]) assert.eq(list(x.union([5, 1])), [1, 2, 3, 5]) assert.eq(list(x.union((6, 5, 4))), [1, 2, 3, 6, 5, 4]) assert.fails(lambda : x.union([1, 2, {}]), "unhashable type: dict") +assert.fails(lambda : x.union(1, 2, 3), "argument #1 is not iterable: int") # set.update (allows any iterable for the right operand) # The update function will mutate the set so the tests below are # scoped using a function. + +def test_update_return_value(): + assert.eq(set(x).update(y), None) + +test_update_return_value() + def test_update_elems_singular(): s = set("a".elems()) s.update("b".elems()) @@ -130,7 +139,7 @@ test_update_non_iterable() def test_update_kwargs(): s = set(x) - assert.fails(lambda: x.update(gee = [3, 4]), "update: update does not accept keyword arguments") + assert.fails(lambda: x.update(gee = [3, 4]), "update: does not accept keyword arguments") test_update_kwargs()