diff --git a/sync/threadgroup.go b/sync/threadgroup.go
index 32abb43925..8a661fff30 100644
--- a/sync/threadgroup.go
+++ b/sync/threadgroup.go
@@ -26,8 +26,9 @@ type ThreadGroup struct {
 	afterStopFns []func()
 
 	once     sync.Once
-	mu       sync.Mutex
 	stopChan chan struct{}
+	bmu      sync.Mutex // Ensures blocking between calls to 'Add', 'Flush', and 'Stop'
+	mu       sync.Mutex // Protects the 'onStopFns' and 'afterStopFns' variables
 	wg       sync.WaitGroup
 }
 
@@ -49,8 +50,8 @@ func (tg *ThreadGroup) isStopped() bool {
 
 // Add increments the thread group counter.
 func (tg *ThreadGroup) Add() error {
-	tg.mu.Lock()
-	defer tg.mu.Unlock()
+	tg.bmu.Lock()
+	defer tg.bmu.Unlock()
 
 	if tg.isStopped() {
 		return ErrStopped
@@ -105,8 +106,8 @@ func (tg *ThreadGroup) Done() {
 // called 'tg.Done'. This in effect 'flushes' the module, letting it complete
 // any tasks that are open before taking on new ones.
 func (tg *ThreadGroup) Flush() error {
-	tg.mu.Lock()
-	defer tg.mu.Unlock()
+	tg.bmu.Lock()
+	defer tg.bmu.Unlock()
 
 	if tg.isStopped() {
 		return ErrStopped
@@ -121,27 +122,31 @@ func (tg *ThreadGroup) Flush() error {
 // order. After Stop is called, most actions will return ErrStopped.
 func (tg *ThreadGroup) Stop() error {
 	// Establish that Stop has been called.
-	tg.mu.Lock()
-	defer tg.mu.Unlock()
+	tg.bmu.Lock()
+	defer tg.bmu.Unlock()
 
 	if tg.isStopped() {
 		return ErrStopped
 	}
 	close(tg.stopChan)
 
+	tg.mu.Lock()
 	for i := len(tg.onStopFns) - 1; i >= 0; i-- {
 		tg.onStopFns[i]()
 	}
 	tg.onStopFns = nil
+	tg.mu.Unlock()
 
 	tg.wg.Wait()
 
 	// After waiting for all resources to release the thread group, iterate
 	// through the stop functions and call them in reverse oreder.
+	tg.mu.Lock()
 	for i := len(tg.afterStopFns) - 1; i >= 0; i-- {
 		tg.afterStopFns[i]()
 	}
 	tg.afterStopFns = nil
+	tg.mu.Unlock()
 
 	return nil
 }
diff --git a/sync/threadgroup_test.go b/sync/threadgroup_test.go
index 1dd99cb7e5..478d5a9730 100644
--- a/sync/threadgroup_test.go
+++ b/sync/threadgroup_test.go
@@ -416,6 +416,59 @@ func TestThreadGroupSiaExample(t *testing.T) {
 	}
 }
 
+// TestAddOnStop checks that you can safely call OnStop from under the
+// protection of an Add call.
+func TestAddOnStop(t *testing.T) {
+	if testing.Short() {
+		t.SkipNow()
+	}
+	t.Parallel()
+
+	var tg ThreadGroup
+	var data int
+	addChan := make(chan struct{})
+	stopChan := make(chan struct{})
+	tg.OnStop(func() {
+		close(stopChan)
+	})
+	go func() {
+		err := tg.Add()
+		if err != nil {
+			t.Error(err)
+			close(addChan)
+			return
+		}
+		close(addChan)
+
+		// Wait for 'Stop' to be called in the parent thread, and then queue
+		// a bunch of 'OnStop' and 'AfterStop' functions before calling
+		// 'Done'.
+		<-stopChan
+		for i := 0; i < 10; i++ {
+			tg.OnStop(func() {
+				data++
+			})
+			tg.AfterStop(func() {
+				data++
+			})
+		}
+		tg.Done()
+	}()
+
+	// Wait for 'Add' to be called in the above thread, to guarantee that
+	// OnStop and AfterStop will be called after 'Add' and 'Stop' have been
+	// called together.
+	<-addChan
+	err := tg.Stop()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if data != 20 {
+		t.Error("20 calls were made to increment data, but value is", data)
+	}
+}
+
 // BenchmarkThreadGroup times how long it takes to add a ton of threads and
 // trigger goroutines that call Done.
 func BenchmarkThreadGroup(b *testing.B) {