workerpool.go
package wpool

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"sync"
)
// config holds the settings that the functional options below can adjust.
type config struct {
	logger            *slog.Logger
	channelBufferSize int
}

func defaultConfig() config {
	return config{
		// disabledSlogHandler is assumed to be a no-op slog.Handler defined in another
		// file of this package, so logging is disabled unless WithLogger is used.
		logger:            slog.New(disabledSlogHandler{}),
		channelBufferSize: 0,
	}
}
// WithChannelBufferSize sets the buffer size of the submission channel; the default is 0 (unbuffered).
func WithChannelBufferSize(s int) func(*config) {
	return func(c *config) { c.channelBufferSize = s }
}

// WithLogger sets the logger used by the pool; by default logging is disabled.
func WithLogger(l *slog.Logger) func(*config) {
	return func(c *config) { c.logger = l }
}
// WorkerPool fans work items of type T out to a fixed set of worker goroutines,
// each of which invokes the callback for every item it receives.
type WorkerPool[T any] struct {
	logger    *slog.Logger
	ch        chan T        // work items are submitted to and consumed from this channel.
	stopped   chan struct{} // closed exactly once to signal that the pool is shutting down.
	cb        func(ctx context.Context, item T)
	workersWG sync.WaitGroup
	chRWMutex sync.RWMutex // guards closing ch against concurrent sends in Submit.
	once      sync.Once    // ensures the shutdown sequence runs only once.
}
// NewWorkerPool creates a pool that runs callback for every submitted item.
// Call Start to launch the workers before submitting work.
func NewWorkerPool[T any](callback func(ctx context.Context, item T), opts ...func(*config)) *WorkerPool[T] {
	c := defaultConfig()
	for _, o := range opts {
		o(&c)
	}

	return &WorkerPool[T]{
		logger:    c.logger,
		ch:        make(chan T, c.channelBufferSize),
		stopped:   make(chan struct{}),
		cb:        callback,
		workersWG: sync.WaitGroup{},
		chRWMutex: sync.RWMutex{},
		once:      sync.Once{},
	}
}
// Submit enqueues an item for processing. It blocks until the item is accepted, the
// context is done, or the pool is stopped, in which case ErrWorkerPoolStopped is returned.
func (p *WorkerPool[T]) Submit(ctx context.Context, item T) error {
	p.chRWMutex.RLock() // acquire the read lock to send to ch.
	defer p.chRWMutex.RUnlock()

	// To avoid writing to a closed p.ch and causing a panic, check whether we are stopped first.
	// In the select below this cannot happen, because p.ch cannot be closed while we hold the lock.
	select {
	case <-p.stopped:
		return ErrWorkerPoolStopped
	case <-ctx.Done():
		return fmt.Errorf("worker pool item submission failed due to context cancellation: %w", ctx.Err())
	default:
	}

	select {
	case <-p.stopped: // covers the case where the pool stops while we are waiting to send (because the channel is full).
		return ErrWorkerPoolStopped
	case p.ch <- item:
		return nil
	case <-ctx.Done():
		return fmt.Errorf("worker pool item submission failed due to context cancellation: %w", ctx.Err())
	}
}
// Start launches numOfWorkers worker goroutines that process submitted items until
// the pool stops or ctx is done. It starts nothing when numOfWorkers <= 0.
func (p *WorkerPool[T]) Start(ctx context.Context, numOfWorkers int) {
	p.logger.InfoContext(ctx, "worker pool starting", slog.Int("workers_count", numOfWorkers))
	if numOfWorkers <= 0 {
		return
	}

	p.workersWG.Add(numOfWorkers)
	for i := range numOfWorkers {
		go p.worker(ctx, i)
	}
}
// worker consumes items from the channel until it is closed or the context is done.
func (p *WorkerPool[T]) worker(ctx context.Context, id int) {
	defer p.workersWG.Done()

	for {
		select {
		case item, open := <-p.ch:
			if !open { // Channel has been closed.
				p.logger.DebugContext(ctx, "worker channel was closed", slog.Int("worker_id", id))
				return
			}
			p.cb(ctx, item)
		case <-ctx.Done(): // Context is done (canceled or deadline exceeded).
			p.logger.DebugContext(ctx, "worker context is done", slog.Int("worker_id", id))
			// Stop is called in a separate goroutine because it waits for all workers,
			// including this one, to exit.
			go p.Stop(ctx)
			return
		}
	}
}
// ErrWorkerPoolStopped is returned by Submit when the pool has been stopped.
var ErrWorkerPoolStopped = errors.New("worker pool is stopped")
// Stop shuts the pool down and waits for the workers to finish. It is safe to call multiple times.
func (p *WorkerPool[T]) Stop(ctx context.Context) {
	p.once.Do(func() { p.close(ctx) })
}
func (p *WorkerPool[T]) close(ctx context.Context) {
	p.logger.InfoContext(ctx, "worker pool shutting down")
	close(p.stopped) // stop receiving new submissions.

	func() {
		p.chRWMutex.Lock() // acquire the write lock to close ch; no Submit can be mid-send while we hold it.
		defer p.chRWMutex.Unlock()
		close(p.ch) // stop accepting.
	}()

	p.workersWG.Wait() // wait for the workers to stop.
	p.logger.InfoContext(ctx, "worker pool shutdown completed")
}
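
// Usage sketch: a minimal example of driving the pool from a caller outside this
// package. The int item type, buffer size, worker count, and print-only callback are
// illustrative assumptions, not part of the package.
//
//	ctx := context.Background()
//	pool := wpool.NewWorkerPool(func(ctx context.Context, item int) {
//		fmt.Println("processing", item)
//	}, wpool.WithChannelBufferSize(8))
//
//	pool.Start(ctx, 4)
//	for i := 0; i < 10; i++ {
//		if err := pool.Submit(ctx, i); err != nil {
//			fmt.Println("submit failed:", err) // e.g. ErrWorkerPoolStopped or a canceled context.
//		}
//	}
//	pool.Stop(ctx) // closes the channel and waits for the workers to drain it.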