From b5700ec1cb8a407495248ce2e5ecd7f71fa4a58f Mon Sep 17 00:00:00 2001 From: Dan Gottlieb Date: Thu, 19 Sep 2024 11:20:26 -0400 Subject: [PATCH] RSDK-8802: Add StoppableWorkers.StartTimer/NextTick. (#353) --- stoppable_workers.go | 53 +++++++++++++++++++++++++++++++++------ stoppable_workers_test.go | 45 +++++++++++++++++++++++++-------- 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/stoppable_workers.go b/stoppable_workers.go index 992b9040..7d60ac64 100644 --- a/stoppable_workers.go +++ b/stoppable_workers.go @@ -3,6 +3,7 @@ package utils import ( "context" "sync" + "time" ) // StoppableWorkers is a collection of goroutines that can be stopped at a @@ -44,16 +45,49 @@ func NewBackgroundStoppableWorkers(workers ...func(context.Context)) *StoppableW return sw } -// Add starts up a goroutine for the passed-in function. Workers: -// -// - MUST respond appropriately to errors on the context parameter. -// - MUST NOT add more workers to the `StoppableWorkers` group to which -// they belong. +// NewStoppableWorkerWithTicker creates a `StoppableWorkers` object with a single worker that gets +// called every `tickRate`. Calls to the input `worker` function are serialized. I.e: a slow "work" +// iteration will just slow down when the next one is called. +func NewStoppableWorkerWithTicker(tickRate time.Duration, workFn func(context.Context)) *StoppableWorkers { + ctx, cancelFunc := context.WithCancel(context.Background()) + sw := &StoppableWorkers{ctx: ctx, cancelFunc: cancelFunc} + sw.workers.Add(1) + PanicCapturingGo(func() { + defer sw.workers.Done() + + timer := time.NewTicker(tickRate) + defer timer.Stop() + for { + select { + case <-ctx.Done(): + return + default: + } + + select { + case <-timer.C: + workFn(ctx) + case <-ctx.Done(): + return + } + } + }) + + return sw +} + +// Add starts up a goroutine for the passed-in function. Workers must respond appropriately to +// errors on the context parameter. // // The worker will not be added if the StoppableWorkers instance has already // been stopped. Any `panic`s from workers will be `recover`ed and logged. func (sw *StoppableWorkers) Add(worker func(context.Context)) { - // Read-lock to allow concurrent worker addition. The Stop method will write-lock. + // Acquire the read lock to allow concurrent worker addition. The Stop method will + // write-lock. `Add` is guaranteed to either: + // - Observe the context is canceled -- the worker will not be run, nor will the `workers` + // WaitGroup be incremented + // - Observe the context is not canceled atomically with incrementing the `workers` + // WaitGroup. `Stop` is guaranteed to wait for this new worker to complete before returning. sw.mu.RLock() if sw.ctx.Err() != nil { sw.mu.RUnlock() @@ -70,13 +104,16 @@ func (sw *StoppableWorkers) Add(worker func(context.Context)) { // Stop idempotently shuts down all the goroutines we started up. func (sw *StoppableWorkers) Stop() { + // Call `cancelFunc` with the write lock that competes with "readers" that can add workers. This + // guarantees `Add` worker calls that start a goroutine have incremented the `workers` WaitGroup + // prior to `Stop` calling `Wait`. sw.mu.Lock() - defer sw.mu.Unlock() if sw.ctx.Err() != nil { return } - sw.cancelFunc() + sw.mu.Unlock() + sw.workers.Wait() } diff --git a/stoppable_workers_test.go b/stoppable_workers_test.go index d0ceae7f..ecd8aca3 100644 --- a/stoppable_workers_test.go +++ b/stoppable_workers_test.go @@ -1,4 +1,4 @@ -package utils_test +package utils import ( "bytes" @@ -7,8 +7,6 @@ import ( "time" "go.viam.com/test" - - "go.viam.com/utils" ) func TestStoppableWorkers(t *testing.T) { @@ -17,7 +15,7 @@ func TestStoppableWorkers(t *testing.T) { ctx := context.Background() t.Run("one worker", func(t *testing.T) { - sw := utils.NewStoppableWorkers(ctx) + sw := NewStoppableWorkers(ctx) sw.Add(normalWorker) sw.Stop() ctx := sw.Context() @@ -26,7 +24,7 @@ func TestStoppableWorkers(t *testing.T) { }) t.Run("one worker background constructor", func(t *testing.T) { - sw := utils.NewBackgroundStoppableWorkers(normalWorker) + sw := NewBackgroundStoppableWorkers(normalWorker) sw.Stop() ctx := sw.Context() test.That(t, ctx, test.ShouldNotBeNil) @@ -34,7 +32,7 @@ func TestStoppableWorkers(t *testing.T) { }) t.Run("heavy workers", func(t *testing.T) { - sw := utils.NewStoppableWorkers(ctx) + sw := NewStoppableWorkers(ctx) sw.Add(heavyWorker) sw.Add(heavyWorker) sw.Add(heavyWorker) @@ -75,7 +73,7 @@ func TestStoppableWorkers(t *testing.T) { } } - sw := utils.NewBackgroundStoppableWorkers(writeWorker, readWorker) + sw := NewBackgroundStoppableWorkers(writeWorker, readWorker) // Sleep for a second to let concurrent workers do work. time.Sleep(500 * time.Millisecond) sw.Stop() @@ -84,26 +82,51 @@ func TestStoppableWorkers(t *testing.T) { }) t.Run("nested workers", func(t *testing.T) { - sw := utils.NewStoppableWorkers(ctx) + sw := NewStoppableWorkers(ctx) sw.Add(nestedWorkersWorker) sw.Stop() }) t.Run("panicking worker", func(t *testing.T) { - sw := utils.NewStoppableWorkers(ctx) + sw := NewStoppableWorkers(ctx) // Both adding and stopping a panicking worker should cause no `panic`s. sw.Add(panickingWorker) sw.Stop() }) t.Run("already stopped", func(t *testing.T) { - sw := utils.NewStoppableWorkers(ctx) + sw := NewStoppableWorkers(ctx) sw.Stop() sw.Add(normalWorker) // adding after Stop should cause no `panic` sw.Stop() // stopping twice should cause no `panic` }) } +func TestStoppableWorkersWithTicker(t *testing.T) { + timesCalled := 0 + workFn := func(ctx context.Context) { + timesCalled++ + select { + case <-time.After(24 * time.Hour): + t.Log("Failed to observe `Stop` call.") + // Realistically, the go test timeout will be hit and not this `FailNow` call. + t.FailNow() + case <-ctx.Done(): + return + } + } + + // Create a worker with a ticker << the sleep time the test will use. The work function + // increments a counter and hangs. This test will logically assert that: + // - The work function was called exactly once. + // - The work function was passed a context that observed `Stop` was called. + sw := NewStoppableWorkerWithTicker(time.Millisecond, workFn) + time.Sleep(time.Second) + sw.Stop() + + test.That(t, timesCalled, test.ShouldEqual, 1) +} + func normalWorker(ctx context.Context) { for { select { @@ -131,7 +154,7 @@ func heavyWorker(ctx context.Context) { } func nestedWorkersWorker(ctx context.Context) { - nestedSW := utils.NewStoppableWorkers(ctx) + nestedSW := NewStoppableWorkers(ctx) nestedSW.Add(normalWorker) normalWorker(ctx)