diff --git a/examle_test.go b/examle_test.go index b650d6c..0eb1dce 100644 --- a/examle_test.go +++ b/examle_test.go @@ -62,3 +62,15 @@ func ExampleConcurrency() { // Output: // wait 2s } + +func ExampleDesired() { + g, ctx, _ := racegroup.WithContext(context.Background(), racegroup.Desired(2)) + g.Go(wait(ctx, 3*time.Second)) + g.Go(wait(ctx, 2*time.Second)) + g.Go(wait(ctx, 1*time.Second)) + g.Wait() + + // Output: + // wait 1s + // wait 2s +} diff --git a/option.go b/option.go index 7aae309..f5ea90a 100644 --- a/option.go +++ b/option.go @@ -23,3 +23,14 @@ func Concurrency(i int) Option { return nil } } + +// Desired returns an Option that sets number of desired completed tasks. +func Desired(i int) Option { + return func(g *Group) error { + if i < 1 { + return errors.New("desired option must be greater than zero") + } + g.desired = int64(i) + return nil + } +} diff --git a/racegroup.go b/racegroup.go index 3020c56..d9681c8 100644 --- a/racegroup.go +++ b/racegroup.go @@ -5,6 +5,7 @@ package racegroup import ( "context" "sync" + "sync/atomic" ) // A Group is a collection of goroutines working on subtasks. @@ -14,12 +15,14 @@ type Group struct { errHandler func(error) semaphore chan struct{} + desired int64 + completed int64 } // WithContext returns a new Group and an associated Context derived from ctx. func WithContext(ctx context.Context, opts ...Option) (*Group, context.Context, error) { ctx, cancel := context.WithCancel(ctx) - g := &Group{cancel: cancel} + g := &Group{cancel: cancel, desired: 1} for _, opt := range opts { if err := opt(g); err != nil { return nil, nil, err @@ -38,7 +41,8 @@ func (g *Group) Wait() { // Go calls the given function in a new goroutine. // -// The first call to return a nil error cancels the group. +// If more than or equal to desired count subtasks are completed, +// cancels the group. func (g *Group) Go(f func() error) { g.wg.Add(1) if g.semaphore != nil { @@ -58,7 +62,9 @@ func (g *Group) Go(f func() error) { g.errHandler(err) } } else { - g.cancel() + if atomic.AddInt64(&g.completed, 1) >= g.desired { + g.cancel() + } } }() }