Skip to content

Commit

Permalink
RSDK-8802: Add StoppableWorkers.StartTimer/NextTick. (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
dgottlieb authored Sep 19, 2024
1 parent 56f9f92 commit b5700ec
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 19 deletions.
53 changes: 45 additions & 8 deletions stoppable_workers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package utils
import (
"context"
"sync"
"time"
)

// StoppableWorkers is a collection of goroutines that can be stopped at a
Expand Down Expand Up @@ -44,16 +45,49 @@ func NewBackgroundStoppableWorkers(workers ...func(context.Context)) *StoppableW
return sw
}

// Add starts up a goroutine for the passed-in function. Workers:
//
// - MUST respond appropriately to errors on the context parameter.
// - MUST NOT add more workers to the `StoppableWorkers` group to which
// they belong.
// NewStoppableWorkerWithTicker creates a `StoppableWorkers` object with a single worker that gets
// called every `tickRate`. Calls to the input `worker` function are serialized. I.e: a slow "work"
// iteration will just slow down when the next one is called.
func NewStoppableWorkerWithTicker(tickRate time.Duration, workFn func(context.Context)) *StoppableWorkers {
ctx, cancelFunc := context.WithCancel(context.Background())
sw := &StoppableWorkers{ctx: ctx, cancelFunc: cancelFunc}
sw.workers.Add(1)
PanicCapturingGo(func() {
defer sw.workers.Done()

timer := time.NewTicker(tickRate)
defer timer.Stop()
for {
select {
case <-ctx.Done():
return
default:
}

select {
case <-timer.C:
workFn(ctx)
case <-ctx.Done():
return
}
}
})

return sw
}

// Add starts up a goroutine for the passed-in function. Workers must respond appropriately to
// errors on the context parameter.
//
// The worker will not be added if the StoppableWorkers instance has already
// been stopped. Any `panic`s from workers will be `recover`ed and logged.
func (sw *StoppableWorkers) Add(worker func(context.Context)) {
// Read-lock to allow concurrent worker addition. The Stop method will write-lock.
// Acquire the read lock to allow concurrent worker addition. The Stop method will
// write-lock. `Add` is guaranteed to either:
// - Observe the context is canceled -- the worker will not be run, nor will the `workers`
// WaitGroup be incremented
// - Observe the context is not canceled atomically with incrementing the `workers`
// WaitGroup. `Stop` is guaranteed to wait for this new worker to complete before returning.
sw.mu.RLock()
if sw.ctx.Err() != nil {
sw.mu.RUnlock()
Expand All @@ -70,13 +104,16 @@ func (sw *StoppableWorkers) Add(worker func(context.Context)) {

// Stop idempotently shuts down all the goroutines we started up.
func (sw *StoppableWorkers) Stop() {
// Call `cancelFunc` with the write lock that competes with "readers" that can add workers. This
// guarantees `Add` worker calls that start a goroutine have incremented the `workers` WaitGroup
// prior to `Stop` calling `Wait`.
sw.mu.Lock()
defer sw.mu.Unlock()
if sw.ctx.Err() != nil {
return
}

sw.cancelFunc()
sw.mu.Unlock()

sw.workers.Wait()
}

Expand Down
45 changes: 34 additions & 11 deletions stoppable_workers_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package utils_test
package utils

import (
"bytes"
Expand All @@ -7,8 +7,6 @@ import (
"time"

"go.viam.com/test"

"go.viam.com/utils"
)

func TestStoppableWorkers(t *testing.T) {
Expand All @@ -17,7 +15,7 @@ func TestStoppableWorkers(t *testing.T) {
ctx := context.Background()

t.Run("one worker", func(t *testing.T) {
sw := utils.NewStoppableWorkers(ctx)
sw := NewStoppableWorkers(ctx)
sw.Add(normalWorker)
sw.Stop()
ctx := sw.Context()
Expand All @@ -26,15 +24,15 @@ func TestStoppableWorkers(t *testing.T) {
})

t.Run("one worker background constructor", func(t *testing.T) {
sw := utils.NewBackgroundStoppableWorkers(normalWorker)
sw := NewBackgroundStoppableWorkers(normalWorker)
sw.Stop()
ctx := sw.Context()
test.That(t, ctx, test.ShouldNotBeNil)
test.That(t, ctx.Err(), test.ShouldBeError, context.Canceled)
})

t.Run("heavy workers", func(t *testing.T) {
sw := utils.NewStoppableWorkers(ctx)
sw := NewStoppableWorkers(ctx)
sw.Add(heavyWorker)
sw.Add(heavyWorker)
sw.Add(heavyWorker)
Expand Down Expand Up @@ -75,7 +73,7 @@ func TestStoppableWorkers(t *testing.T) {
}
}

sw := utils.NewBackgroundStoppableWorkers(writeWorker, readWorker)
sw := NewBackgroundStoppableWorkers(writeWorker, readWorker)
// Sleep for a second to let concurrent workers do work.
time.Sleep(500 * time.Millisecond)
sw.Stop()
Expand All @@ -84,26 +82,51 @@ func TestStoppableWorkers(t *testing.T) {
})

t.Run("nested workers", func(t *testing.T) {
sw := utils.NewStoppableWorkers(ctx)
sw := NewStoppableWorkers(ctx)
sw.Add(nestedWorkersWorker)
sw.Stop()
})

t.Run("panicking worker", func(t *testing.T) {
sw := utils.NewStoppableWorkers(ctx)
sw := NewStoppableWorkers(ctx)
// Both adding and stopping a panicking worker should cause no `panic`s.
sw.Add(panickingWorker)
sw.Stop()
})

t.Run("already stopped", func(t *testing.T) {
sw := utils.NewStoppableWorkers(ctx)
sw := NewStoppableWorkers(ctx)
sw.Stop()
sw.Add(normalWorker) // adding after Stop should cause no `panic`
sw.Stop() // stopping twice should cause no `panic`
})
}

func TestStoppableWorkersWithTicker(t *testing.T) {
timesCalled := 0
workFn := func(ctx context.Context) {
timesCalled++
select {
case <-time.After(24 * time.Hour):
t.Log("Failed to observe `Stop` call.")
// Realistically, the go test timeout will be hit and not this `FailNow` call.
t.FailNow()
case <-ctx.Done():
return
}
}

// Create a worker with a ticker << the sleep time the test will use. The work function
// increments a counter and hangs. This test will logically assert that:
// - The work function was called exactly once.
// - The work function was passed a context that observed `Stop` was called.
sw := NewStoppableWorkerWithTicker(time.Millisecond, workFn)
time.Sleep(time.Second)
sw.Stop()

test.That(t, timesCalled, test.ShouldEqual, 1)
}

func normalWorker(ctx context.Context) {
for {
select {
Expand Down Expand Up @@ -131,7 +154,7 @@ func heavyWorker(ctx context.Context) {
}

func nestedWorkersWorker(ctx context.Context) {
nestedSW := utils.NewStoppableWorkers(ctx)
nestedSW := NewStoppableWorkers(ctx)
nestedSW.Add(normalWorker)

normalWorker(ctx)
Expand Down

0 comments on commit b5700ec

Please sign in to comment.