Skip to content

Commit

Permalink
separate ThreadGroup locking by function
Browse files Browse the repository at this point in the history
The threadgroup has two separate things that it is protecting when it
locks the mutex. One group of things is to block the execution of other
threads, in the case of 'Add', 'Flush', and 'Stop'. The other group of
things is to protect the ThreadGroup state variables, namely 'onStopFns'
and 'afterStopFns'.

Previously, you would get a deadlock with the following:

tg.Add()
go func {
	tg.Stop()
}
time.Sleep(1 * time.Second)
tg.OnStop()
tg.Done()

The deadlock happens because tg.Stop() is blocking the OnStop call while
it waits for the tg.Add to return, but the tg.Add will not return until
tg.OnStop can get the lock.

I don't see any reason for this behavior to be forbidden. You want to be
able to do something like:

tg.Add()
tg.RegisterRPC()
tg.OnStop(func() {
	tg.UnregisterRPC()
})
tg.Done()

and currently that's not safe code if you are assuming that Stop() could
be called at any time.
  • Loading branch information
DavidVorick committed Jul 17, 2016
1 parent 2856326 commit 3ffc5b3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 7 deletions.
19 changes: 12 additions & 7 deletions sync/threadgroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ type ThreadGroup struct {
afterStopFns []func()

once sync.Once
mu sync.Mutex
stopChan chan struct{}
bmu sync.Mutex // Ensures blocking between calls to 'Add', 'Flush', and 'Stop'
mu sync.Mutex // Protects the 'onStopFns' and 'afterStopFns' variable
wg sync.WaitGroup
}

Expand All @@ -49,8 +50,8 @@ func (tg *ThreadGroup) isStopped() bool {

// Add increments the thread group counter.
func (tg *ThreadGroup) Add() error {
tg.mu.Lock()
defer tg.mu.Unlock()
tg.bmu.Lock()
defer tg.bmu.Unlock()

if tg.isStopped() {
return ErrStopped
Expand Down Expand Up @@ -105,8 +106,8 @@ func (tg *ThreadGroup) Done() {
// called 'tg.Done'. This in effect 'flushes' the module, letting it complete
// any tasks that are open before taking on new ones.
func (tg *ThreadGroup) Flush() error {
tg.mu.Lock()
defer tg.mu.Unlock()
tg.bmu.Lock()
defer tg.bmu.Unlock()

if tg.isStopped() {
return ErrStopped
Expand All @@ -121,27 +122,31 @@ func (tg *ThreadGroup) Flush() error {
// order. After Stop is called, most actions will return ErrStopped.
func (tg *ThreadGroup) Stop() error {
// Establish that Stop has been called.
tg.mu.Lock()
defer tg.mu.Unlock()
tg.bmu.Lock()
defer tg.bmu.Unlock()

if tg.isStopped() {
return ErrStopped
}
close(tg.stopChan)

tg.mu.Lock()
for i := len(tg.onStopFns) - 1; i >= 0; i-- {
tg.onStopFns[i]()
}
tg.onStopFns = nil
tg.mu.Unlock()

tg.wg.Wait()

// After waiting for all resources to release the thread group, iterate
// through the stop functions and call them in reverse oreder.
tg.mu.Lock()
for i := len(tg.afterStopFns) - 1; i >= 0; i-- {
tg.afterStopFns[i]()
}
tg.afterStopFns = nil
tg.mu.Unlock()
return nil
}

Expand Down
51 changes: 51 additions & 0 deletions sync/threadgroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,57 @@ func TestThreadGroupSiaExample(t *testing.T) {
}
}

// TestAddOnStop checks that you can safely call OnStop from under the
// protection of an Add call.
func TestAddOnStop(t *testing.T) {
if testing.Short() {
t.SkipNow()
}
t.Parallel()

var tg ThreadGroup
var data int
addChan := make(chan struct{})
stopChan := make(chan struct{})
tg.OnStop(func() {
close(stopChan)
})
go func() {
err := tg.Add()
if err != nil {
t.Fatal(err)
}
close(addChan)

// Wait for the call to 'Stop' to be called in the parent thread, and
// then queue a bunch of 'OnStop' and 'AfterStop' functions before
// calling 'Done'.
<-stopChan
for i := 0; i < 10; i++ {
tg.OnStop(func() {
data++
})
tg.AfterStop(func() {
data++
})
}
tg.Done()
}()

// Wait for 'Add' to be called in the above thread, to guarantee that
// OnStop and AfterStop will be called after 'Add' and 'Stop' have been
// called together.
<-addChan
err := tg.Stop()
if err != nil {
t.Fatal(err)
}

if data != 20 {
t.Error("20 calls were made to increment data, but value is", data)
}
}

// BenchmarkThreadGroup times how long it takes to add a ton of threads and
// trigger goroutines that call Done.
func BenchmarkThreadGroup(b *testing.B) {
Expand Down

0 comments on commit 3ffc5b3

Please sign in to comment.