diff --git a/pkg/observable/channel/observable.go b/pkg/observable/channel/observable.go index 26958a75b..6c92d29d3 100644 --- a/pkg/observable/channel/observable.go +++ b/pkg/observable/channel/observable.go @@ -2,8 +2,6 @@ package channel import ( "context" - "sync" - "pocket/pkg/observable" ) @@ -13,7 +11,10 @@ import ( // defaultSubscribeBufferSize is the buffer size of a observable's publish channel. const defaultPublishBufferSize = 50 -var _ observable.Observable[any] = (*channelObservable[any])(nil) +var ( + _ observable.Observable[any] = (*channelObservable[any])(nil) + _ observerManager[any] = (*channelObservable[any])(nil) +) // option is a function which receives and can modify the channelObservable state. type option[V any] func(obs *channelObservable[V]) @@ -21,14 +22,14 @@ type option[V any] func(obs *channelObservable[V]) // channelObservable implements the observable.Observable interface and can be notified // by sending on its corresponding publishCh channel. type channelObservable[V any] struct { + // embed observerManager to encapsulate concurrent-safe read/write access to + // observers. This also allows higher-level objects to wrap this observable + // without knowing its specific type by asserting that it implements the + // observerManager interface. + observerManager[V] // publishCh is an observable-wide channel that is used to receive values // which are subsequently fanned out to observers. publishCh chan V - // observersMu protects observers from concurrent access/updates - observersMu *sync.RWMutex - // observers is a list of channelObservers that will be notified when publishCh - // receives a new value. - observers []*channelObserver[V] } // NewObservable creates a new observable which is notified when the publishCh @@ -36,8 +37,7 @@ type channelObservable[V any] struct { func NewObservable[V any](opts ...option[V]) (observable.Observable[V], chan<- V) { // initialize an observable that publishes messages from 1 publishCh to N observers obs := &channelObservable[V]{ - observersMu: &sync.RWMutex{}, - observers: []*channelObserver[V]{}, + observerManager: newObserverManager[V](), } for _, opt := range opts { @@ -64,125 +64,37 @@ func WithPublisher[V any](publishCh chan V) option[V] { } } -// Next synchronously returns the next value from the observable. -func (obsvbl *channelObservable[V]) Next(ctx context.Context) V { - tempObserver := obsvbl.Subscribe(ctx) - defer tempObserver.Unsubscribe() - - val := <-tempObserver.Ch() - return val -} - // Subscribe returns an observer which is notified when the publishCh channel // receives a value. -func (obsvbl *channelObservable[V]) Subscribe(ctx context.Context) observable.Observer[V] { - // must (write) lock observersMu so that we can safely append to the observers list - obsvbl.observersMu.Lock() - defer obsvbl.observersMu.Unlock() - - observer := NewObserver[V](ctx, obsvbl.onUnsubscribe) - obsvbl.observers = append(obsvbl.observers, observer) +func (obs *channelObservable[V]) Subscribe(ctx context.Context) observable.Observer[V] { + // Create a new observer and add it to the list of observers to be notified + // when publishCh receives a new value. 
+ observer := NewObserver[V](ctx, obs.observerManager.remove) + obs.observerManager.add(observer) - // caller can rely on context cancellation or call UnsubscribeAll() to unsubscribe + // caller can rely on context cancelation or call UnsubscribeAll() to unsubscribe // active observers if ctx != nil { // asynchronously wait for the context to be done and then unsubscribe // this observer. - go goUnsubscribeOnDone[V](ctx, observer) + go obs.observerManager.goUnsubscribeOnDone(ctx, observer) } return observer } // UnsubscribeAll unsubscribes and removes all observers from the observable. -func (obsvbl *channelObservable[V]) UnsubscribeAll() { - obsvbl.unsubscribeAll() -} - -// unsubscribeAll unsubscribes and removes all observers from the observable. -func (obsvbl *channelObservable[V]) unsubscribeAll() { - // Copy currentObservers to avoid holding the lock while unsubscribing them. - // The observers at the time of locking, prior to copying, are the canonical - // set of observers which are unsubscribed. - // New or existing Observers may (un)subscribe while the observable is closing. - // Any such observers won't be isClosed but will also stop receiving notifications - // immediately (if they receive any at all). - currentObservers := obsvbl.copyObservers() - for _, observer := range currentObservers { - observer.Unsubscribe() - } - - // Reset observers to an empty list. This purges any observers which might have - // subscribed while the observable was closing. - obsvbl.observersMu.Lock() - obsvbl.observers = []*channelObserver[V]{} - obsvbl.observersMu.Unlock() +func (obs *channelObservable[V]) UnsubscribeAll() { + obs.observerManager.removeAll() } // goPublish to the publishCh and notify observers when values are received. // This function is blocking and should be run in a goroutine. -func (obsvbl *channelObservable[V]) goPublish() { - for notification := range obsvbl.publishCh { - // Copy currentObservers to avoid holding the lock while notifying them. - // New or existing Observers may (un)subscribe while this notification - // is being fanned out. - // The observers at the time of locking, prior to copying, are the canonical - // set of observers which receive this notification. - currentObservers := obsvbl.copyObservers() - for _, obsvr := range currentObservers { - // TODO_CONSIDERATION: perhaps continue trying to avoid making this - // notification async as it would effectively use goroutines - // in memory as a buffer (unbounded). - obsvr.notify(notification) - } +func (obs *channelObservable[V]) goPublish() { + for notification := range obs.publishCh { + obs.observerManager.notifyAll(notification) } // Here we know that the publisher channel has been closed. // Unsubscribe all observers as they can no longer receive notifications. - obsvbl.unsubscribeAll() -} - -// copyObservers returns a copy of the current observers list. It is safe to -// call concurrently. -func (obsvbl *channelObservable[V]) copyObservers() (observers []*channelObserver[V]) { - defer obsvbl.observersMu.RUnlock() - - // This loop blocks on acquiring a read lock on observersMu. If TryRLock - // fails, the loop continues until it succeeds. This is intended to give - // callers a guarantee that this copy operation won't contribute to a deadlock. 
- for { - // block until a read lock can be acquired - if obsvbl.observersMu.TryRLock() { - break - } - } - - observers = make([]*channelObserver[V], len(obsvbl.observers)) - copy(observers, obsvbl.observers) - - return observers -} - -// goUnsubscribeOnDone unsubscribes from the subscription when the context is done. -// It is a blocking function and intended to be called in a goroutine. -func goUnsubscribeOnDone[V any](ctx context.Context, observer observable.Observer[V]) { - <-ctx.Done() - if observer.IsClosed() { - return - } - observer.Unsubscribe() -} - -// onUnsubscribe returns a function that removes a given observer from the -// observable's list of observers. -func (obsvbl *channelObservable[V]) onUnsubscribe(toRemove *channelObserver[V]) { - // must (write) lock to iterate over and modify the observers list - obsvbl.observersMu.Lock() - defer obsvbl.observersMu.Unlock() - - for i, observer := range obsvbl.observers { - if observer == toRemove { - obsvbl.observers = append((obsvbl.observers)[:i], (obsvbl.observers)[i+1:]...) - break - } - } + obs.observerManager.removeAll() } diff --git a/pkg/observable/channel/observable_test.go b/pkg/observable/channel/observable_test.go index 6ec301cfa..e94679630 100644 --- a/pkg/observable/channel/observable_test.go +++ b/pkg/observable/channel/observable_test.go @@ -337,7 +337,7 @@ func TestChannelObservable_SequentialPublishAndUnsubscription(t *testing.T) { // TODO_TECHDEBT/TODO_INCOMPLETE: add coverage for active observers closing when publishCh closes. func TestChannelObservable_ObserversCloseOnPublishChannelClose(t *testing.T) { - t.Skip("add coverage: all observers should unsubscribeAll when publishCh closes") + t.Skip("add coverage: all observers should unsubscribe when publishCh closes") } func delayedPublishFactory[V any](publishCh chan<- V, delay time.Duration) func(value V) { diff --git a/pkg/observable/channel/observer.go b/pkg/observable/channel/observer.go index 3a2455e64..a989b2092 100644 --- a/pkg/observable/channel/observer.go +++ b/pkg/observable/channel/observer.go @@ -29,9 +29,10 @@ var _ observable.Observer[any] = (*channelObserver[any])(nil) // channelObserver implements the observable.Observer interface. type channelObserver[V any] struct { ctx context.Context - // onUnsubscribe is called in Observer#Unsubscribe, removing the respective - // observer from observers in a concurrency-safe manner. - onUnsubscribe func(toRemove *channelObserver[V]) + // onUnsubscribe is called in Observer#Unsubscribe, closing this observer's + // channel and removing it from the respective obervable's observers list + // in a concurrency-safe manner. + onUnsubscribe func(toRemove observable.Observer[V]) // observerMu protects the observerCh and isClosed fields. observerMu *sync.RWMutex // observerCh is the channel that is used to emit values to the observer. 
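To make the struct-embedding note above concrete: a higher-level type can reuse the observer bookkeeping of any compatible observable by asserting the observerManager interface introduced in observer_manager.go below. A minimal sketch, assuming a hypothetical wrappedObservable type and wrap helper (ToReplayObservable in replay.go applies the same pattern for real):

```go
package channel

import "pocket/pkg/observable"

// wrappedObservable is hypothetical and exists only to illustrate the pattern;
// it is not part of this change.
type wrappedObservable[V any] struct {
	observable.Observable[V] // Subscribe / UnsubscribeAll come from the wrapped observable.
	observerManager[V]       // add / remove / removeAll / notifyAll come from the asserted manager.
}

func wrap[V any](src observable.Observable[V]) *wrappedObservable[V] {
	// Panics if src's implementation does not also satisfy observerManager;
	// channelObservable does, per its `var _ observerManager[any]` assertion.
	return &wrappedObservable[V]{
		Observable:      src,
		observerManager: src.(observerManager[V]),
	}
}
```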
@@ -43,7 +44,7 @@ type channelObserver[V any] struct { isClosed bool } -type UnsubscribeFunc[V any] func(toRemove *channelObserver[V]) +type UnsubscribeFunc[V any] func(toRemove observable.Observer[V]) func NewObserver[V any]( ctx context.Context, diff --git a/pkg/observable/channel/observer_manager.go b/pkg/observable/channel/observer_manager.go new file mode 100644 index 000000000..44807c047 --- /dev/null +++ b/pkg/observable/channel/observer_manager.go @@ -0,0 +1,152 @@ +package channel + +import ( + "context" + "sync" + + "pocket/pkg/observable" +) + +var _ observerManager[any] = (*channelObserverManager[any])(nil) + +// observerManager is an interface intended to be used between an observable and some +// higher-level abstraction and/or observable implementation which would embed it. +// Embedding this interface rather than a channelObservable directly allows for +// more transparency and flexibility in higher-level code. +// NOTE: this interface MUST be used with a common concrete Observer type. +// TODO_CONSIDERATION: Consider whether `observerManager` and `Observable` should remain as separate +// types after some more time and experience using both. +type observerManager[V any] interface { + notifyAll(notification V) + add(toAdd observable.Observer[V]) + remove(toRemove observable.Observer[V]) + removeAll() + goUnsubscribeOnDone(ctx context.Context, observer observable.Observer[V]) +} + +// TODO_CONSIDERATION: if this were a generic implementation, we wouldn't need +// to cast `toAdd` to a channelObserver in add. There are two things +// currently preventing a generic observerManager implementation: +// 1. channelObserver#notify() is not part of the observable.Observer interface +// and is therefore not accessible here. If we move everything into the +// `observable` pkg so that the unexported member is in scope, then the channel +// pkg can't implement it for the same reason, it's an unexported method defined +// in a different pkg. +// 2. == is not defined for a generic Observer type. We would have to add an Equals() +// to the Observer interface. + +// channelObserverManager implements the observerManager interface using +// channelObservers. +type channelObserverManager[V any] struct { + // observersMu protects observers from concurrent access/updates + observersMu *sync.RWMutex + // observers is a list of channelObservers that will be notified when new value + // are received. + observers []*channelObserver[V] +} + +func newObserverManager[V any]() *channelObserverManager[V] { + return &channelObserverManager[V]{ + observersMu: &sync.RWMutex{}, + observers: make([]*channelObserver[V], 0), + } +} + +func (com *channelObserverManager[V]) notifyAll(notification V) { + // Copy currentObservers to avoid holding the lock while notifying them. + // New or existing Observers may (un)subscribe while this notification + // is being fanned out. + // The observers at the time of locking, prior to copying, are the canonical + // set of observers which receive this notification. + currentObservers := com.copyObservers() + for _, obsvr := range currentObservers { + // TODO_TECHDEBT: since this synchronously notifies all observers in a loop, + // it is possible to block here, part-way through notifying all observers, + // on a slow observer consumer (i.e. full buffer). Instead, we should notify + // observers with some limited concurrency of "worker" goroutines. 
+ // The storj/common repo contains such a `Limiter` implementation, see: + // https://github.com/storj/common/blob/main/sync2/limiter.go. + obsvr.notify(notification) + } +} + +// addObserver implements the respective member of observerManager. It is used +// by the channelObservable implementation as well as embedders of observerManager +// (e.g. replayObservable). +// It panics if toAdd is not a channelObserver. +func (com *channelObserverManager[V]) add(toAdd observable.Observer[V]) { + // must (write) lock observersMu so that we can safely append to the observers list + com.observersMu.Lock() + defer com.observersMu.Unlock() + + com.observers = append(com.observers, toAdd.(*channelObserver[V])) +} + +// remove removes a given observer from the observable's list of observers. +// It implements the respective member of observerManager and is used by +// the channelObservable implementation as well as embedders of observerManager +// (e.g. replayObservable). +func (com *channelObserverManager[V]) remove(toRemove observable.Observer[V]) { + // must (write) lock to iterate over and modify the observers list + com.observersMu.Lock() + defer com.observersMu.Unlock() + + for i, observer := range com.observers { + if observer == toRemove { + com.observers = append((com.observers)[:i], (com.observers)[i+1:]...) + break + } + } +} + +// removeAll unsubscribes and removes all observers from the observable. +// It implements the respective member of observerManager and is used by +// the channelObservable implementation as well as embedders of observerManager +// (e.g. replayObservable). +func (com *channelObserverManager[V]) removeAll() { + // Copy currentObservers to avoid holding the lock while unsubscribing them. + // The observers at the time of locking, prior to copying, are the canonical + // set of observers which are unsubscribed. + // New or existing Observers may (un)subscribe while the observable is closing. + // Any such observers won't be isClosed but will also stop receiving notifications + // immediately (if they receive any at all). + currentObservers := com.copyObservers() + for _, observer := range currentObservers { + observer.Unsubscribe() + } + + // Reset observers to an empty list. This purges any observers which might have + // subscribed while the observable was closing. + com.observersMu.Lock() + com.observers = []*channelObserver[V]{} + com.observersMu.Unlock() +} + +// goUnsubscribeOnDone unsubscribes from the subscription when the context is done. +// It is a blocking function and intended to be called in a goroutine. +func (com *channelObserverManager[V]) goUnsubscribeOnDone( + ctx context.Context, + observer observable.Observer[V], +) { + <-ctx.Done() + if observer.IsClosed() { + return + } + observer.Unsubscribe() +} + +// copyObservers returns a copy of the current observers list. It is safe to +// call concurrently. Notably, it is not part of the observerManager interface. +func (com *channelObserverManager[V]) copyObservers() (observers []*channelObserver[V]) { + defer com.observersMu.RUnlock() + + // This loop blocks on acquiring a read lock on observersMu. If TryRLock + // fails, the loop continues until it succeeds. This is intended to give + // callers a guarantee that this copy operation won't contribute to a deadlock. 
+ com.observersMu.RLock() + + observers = make([]*channelObserver[V], len(com.observers)) + copy(observers, com.observers) + + return observers +} diff --git a/pkg/observable/channel/observer_test.go b/pkg/observable/channel/observer_test.go index f8730a422..ccda5c66c 100644 --- a/pkg/observable/channel/observer_test.go +++ b/pkg/observable/channel/observer_test.go @@ -7,20 +7,23 @@ import ( "time" "github.com/stretchr/testify/require" + + "pocket/pkg/observable" ) func TestObserver_Unsubscribe(t *testing.T) { var ( - onUnsubscribeCalled = false publishCh = make(chan int, 1) + onUnsubscribeCalled = false + onUnsubscribe = func(toRemove observable.Observer[int]) { + onUnsubscribeCalled = true + } ) obsvr := &channelObserver[int]{ observerMu: &sync.RWMutex{}, // using a buffered channel to keep the test synchronous - observerCh: publishCh, - onUnsubscribe: func(toRemove *channelObserver[int]) { - onUnsubscribeCalled = true - }, + observerCh: publishCh, + onUnsubscribe: onUnsubscribe, } // should initially be open @@ -37,17 +40,19 @@ func TestObserver_Unsubscribe(t *testing.T) { func TestObserver_ConcurrentUnsubscribe(t *testing.T) { var ( - onUnsubscribeCalled = false publishCh = make(chan int, 1) + onUnsubscribeCalled = false + onUnsubscribe = func(toRemove observable.Observer[int]) { + onUnsubscribeCalled = true + } ) + obsvr := &channelObserver[int]{ ctx: context.Background(), observerMu: &sync.RWMutex{}, // using a buffered channel to keep the test synchronous - observerCh: publishCh, - onUnsubscribe: func(toRemove *channelObserver[int]) { - onUnsubscribeCalled = true - }, + observerCh: publishCh, + onUnsubscribe: onUnsubscribe, } require.Equal(t, false, obsvr.isClosed, "observer channel should initially be open") diff --git a/pkg/observable/channel/replay.go b/pkg/observable/channel/replay.go new file mode 100644 index 000000000..b7c54f877 --- /dev/null +++ b/pkg/observable/channel/replay.go @@ -0,0 +1,237 @@ +package channel + +import ( + "context" + "log" + "sync" + "time" + + "pocket/pkg/observable" +) + +// replayPartialBufferTimeout is the duration to wait for the replay buffer to +// accumulate at least 1 value before returning the accumulated values. +// TODO_CONSIDERATION: perhaps this should be parameterized. +const replayPartialBufferTimeout = 100 * time.Millisecond + +var _ observable.ReplayObservable[any] = (*replayObservable[any])(nil) + +type replayObservable[V any] struct { + // embed observerManager to encapsulate concurrent-safe read/write access to + // observers. This also allows higher-level objects to wrap this observable + // without knowing its specific type by asserting that it implements the + // observerManager interface. + observerManager[V] + // replayBufferSize is the number of notifications to buffer so that they + // can be replayed to new observers. + replayBufferSize int + // replayBufferMu protects replayBuffer from concurrent access/updates. + replayBufferMu sync.RWMutex + // replayBuffer holds the last relayBufferSize number of notifications received + // by this observable. This buffer is replayed to new observers, on subscribing, + // prior to any new notifications being propagated. + replayBuffer []V +} + +// NewReplayObservable returns a new ReplayObservable with the given replay buffer +// replayBufferSize and the corresponding publish channel to notify it of new values. 
+func NewReplayObservable[V any]( + ctx context.Context, + replayBufferSize int, +) (observable.ReplayObservable[V], chan<- V) { + obsvbl, publishCh := NewObservable[V]() + return ToReplayObservable[V](ctx, replayBufferSize, obsvbl), publishCh +} + +// ToReplayObservable returns an observable which replays the last replayBufferSize +// number of values published to the source observable to new observers, before +// publishing new values. +// It panics if srcObservable does not implement the observerManager interface. +// It should only be used with a srcObservable which contains channelObservers +// (i.e. channelObservable or similar). +func ToReplayObservable[V any]( + ctx context.Context, + replayBufferSize int, + srcObsvbl observable.Observable[V], +) observable.ReplayObservable[V] { + // Assert that the source observable implements the observerMngr required + // to embed and wrap it. + observerMngr := srcObsvbl.(observerManager[V]) + + replayObsvbl := &replayObservable[V]{ + observerManager: observerMngr, + replayBufferSize: replayBufferSize, + replayBuffer: make([]V, 0, replayBufferSize), + } + + srcObserver := srcObsvbl.Subscribe(ctx) + go replayObsvbl.goBufferReplayNotifications(srcObserver) + + return replayObsvbl +} + +// Last synchronously returns the last n values from the replay buffer. It blocks +// until at least 1 notification has been accumulated, then waits replayPartialBufferTimeout +// duration before returning all notifications accumulated notifications by that time. +// If the replay buffer contains at least n notifications, this function will only +// block as long as it takes to accumulate and return them. +// If n is greater than the replay buffer size, the entire replay buffer is returned. +func (ro *replayObservable[V]) Last(ctx context.Context, n int) []V { + // Use a temporary observer to accumulate replay values. + // Subscribe will always start with the replay buffer, so we can safely + // leverage it here for syncrhonization (i.e. blocking until at least 1 + // notification has been accumulated). This also eliminates the need for + // locking and/or copying the replay buffer. + tempObserver := ro.Subscribe(ctx) + defer tempObserver.Unsubscribe() + + // If n is greater than the replay buffer size, return the entire replay buffer. + if n > ro.replayBufferSize { + n = ro.replayBufferSize + log.Printf( + "WARN: requested replay buffer size %d is greater than replay buffer capacity %d; returning entire replay buffer", + n, cap(ro.replayBuffer), + ) + } + + // accumulateReplayValues works concurrently and returns a context and cancelation + // function for signaling completion. + return accumulateReplayValues(tempObserver, n) +} + +// Subscribe returns an observer which is notified when the publishCh channel +// receives a value. +func (ro *replayObservable[V]) Subscribe(ctx context.Context) observable.Observer[V] { + ro.replayBufferMu.RLock() + defer ro.replayBufferMu.RUnlock() + + observer := NewObserver[V](ctx, ro.observerManager.remove) + + // Replay all buffered replayBuffer to the observer channel buffer before + // any new values have an opportunity to send on observerCh (i.e. appending + // observer to ro.observers). + // + // TODO_IMPROVE: this assumes that the observer channel buffer is large enough + // to hold all replay (buffered) notifications. 
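+	// If replayBufferSize exceeds the observer channel's buffer capacity, the
+	// replay loop below may block while replayBufferMu is held (the same
+	// slow-consumer concern noted in the TODO_TECHDEBT on observerManager#notifyAll).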
+ for _, notification := range ro.replayBuffer { + observer.notify(notification) + } + + ro.observerManager.add(observer) + + // caller can rely on context cancelation or call UnsubscribeAll() to unsubscribe + // active observers + if ctx != nil { + // asynchronously wait for the context to be done and then unsubscribe + // this observer. + go ro.observerManager.goUnsubscribeOnDone(ctx, observer) + } + + return observer +} + +// UnsubscribeAll unsubscribes and removes all observers from the observable. +func (ro *replayObservable[V]) UnsubscribeAll() { + ro.observerManager.removeAll() +} + +// goBufferReplayNotifications buffers the last n notifications from a source +// observer. It is intended to be run in a goroutine. +func (ro *replayObservable[V]) goBufferReplayNotifications(srcObserver observable.Observer[V]) { + for notification := range srcObserver.Ch() { + ro.replayBufferMu.Lock() + // Add the notification to the buffer. + if len(ro.replayBuffer) < ro.replayBufferSize { + ro.replayBuffer = append(ro.replayBuffer, notification) + } else { + // buffer full, make room for the new notification by removing the + // oldest notification. + ro.replayBuffer = append(ro.replayBuffer[1:], notification) + } + ro.replayBufferMu.Unlock() + } +} + +// accumulateReplayValues synchronously (but concurrently) accumulates n values +// from the observer channel into the slice pointed to by accValues and then returns +// said slice. It cancels the context either when n values have been accumulated +// or when at least 1 value has been accumulated and replayPartialBufferTimeout +// has elapsed. +func accumulateReplayValues[V any](observer observable.Observer[V], n int) []V { + var ( + // accValuesMu protects accValues from concurrent access. + accValuesMu sync.Mutex + // Accumulate replay values in a new slice to avoid (read) locking replayBufferMu. + accValues = new([]V) + // canceling the context will cause the loop in the goroutine to exit. + ctx, cancel = context.WithCancel(context.Background()) + ) + + // Concurrently accumulate n values from the observer channel. + go func() { + // Defer canceling the context and unlocking accValuesMu. The function + // assumes that the mutex is locked when it gets execution control back + // from the loop. + defer func() { + cancel() + accValuesMu.Unlock() + }() + for { + // Lock the mutex to read accValues here and potentially write in + // the first case branch in the select below. + accValuesMu.Lock() + + // The context was canceled since the last iteration. + if ctx.Err() != nil { + return + } + + // We've accumulated n values. + if len(*accValues) >= n { + return + } + + // Receive from the observer's channel if we can, otherwise let + // the loop run. + select { + // Receiving from the observer channel blocks if replayBuffer is empty. + case value, ok := <-observer.Ch(): + // tempObserver was closed concurrently. + if !ok { + return + } + + // Update the accumulated values pointed to by accValues. + *accValues = append(*accValues, value) + default: + // If we can't receive from the observer channel immediately, + // let the loop run. + } + + // Unlock accValuesMu so that the select below gets a chance to check + // the length of *accValues to decide whether to cancel, and it can + // be relocked at the top of the loop as it must be locked when the + // loop exits. + accValuesMu.Unlock() + // Wait a tick before continuing the loop. + time.Sleep(time.Millisecond) + } + }() + + // Wait for N values to be accumulated or timeout. 
When timing out, if we + // have at least 1 value, we can return it. Otherwise, we need to wait for + // the next value to be published (i.e. continue the loop). + for { + select { + case <-ctx.Done(): + return *accValues + case <-time.After(replayPartialBufferTimeout): + accValuesMu.Lock() + if len(*accValues) > 1 { + cancel() + return *accValues + } + accValuesMu.Unlock() + } + } +} diff --git a/pkg/observable/channel/replay_test.go b/pkg/observable/channel/replay_test.go new file mode 100644 index 000000000..a34fb0f92 --- /dev/null +++ b/pkg/observable/channel/replay_test.go @@ -0,0 +1,228 @@ +package channel_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "pocket/internal/testerrors" + "pocket/pkg/observable/channel" +) + +func TestReplayObservable(t *testing.T) { + var ( + replayBufferSize = 3 + values = []int{1, 2, 3, 4, 5} + // the replay buffer is full and has shifted out values with index < + // len(values)-replayBufferSize so Last should return values starting + // from there. + expectedValues = values[len(values)-replayBufferSize:] + errCh = make(chan error, 1) + ctx, cancel = context.WithCancel(context.Background()) + ) + t.Cleanup(cancel) + + // NB: intentionally not using NewReplayObservable() to test ToReplayObservable() directly + // and to retain a reference to the wrapped observable for testing. + obsvbl, publishCh := channel.NewObservable[int]() + replayObsvbl := channel.ToReplayObservable[int](ctx, replayBufferSize, obsvbl) + + // vanilla observer, should be able to receive all values published after subscribing + observer := obsvbl.Subscribe(ctx) + go func() { + for _, expected := range values { + select { + case v := <-observer.Ch(): + if !assert.Equal(t, expected, v) { + errCh <- testerrors.ErrAsync + return + } + case <-time.After(1 * time.Second): + t.Errorf("Did not receive expected value %d in time", expected) + errCh <- testerrors.ErrAsync + return + } + } + }() + + // send all values to the observable's publish channel + for _, value := range values { + publishCh <- value + } + + // allow some time for values to be buffered by the replay observable + time.Sleep(time.Millisecond) + + // replay observer, should receive the last lastN values published prior to + // subscribing followed by subsequently published values + replayObserver := replayObsvbl.Subscribe(ctx) + for _, expected := range expectedValues { + select { + case v := <-replayObserver.Ch(): + require.Equal(t, expected, v) + case <-time.After(1 * time.Second): + t.Fatalf("Did not receive expected value %d in time", expected) + } + } + + // second replay observer, should receive the same values as the first + // even though it subscribed after all values were published and the + // values were already replayed by the first. 
+ replayObserver2 := replayObsvbl.Subscribe(ctx) + for _, expected := range expectedValues { + select { + case v := <-replayObserver2.Ch(): + require.Equal(t, expected, v) + case <-time.After(1 * time.Second): + t.Fatalf("Did not receive expected value %d in time", expected) + } + } +} + +func TestReplayObservable_Last_Full_ReplayBuffer(t *testing.T) { + values := []int{1, 2, 3, 4, 5} + tests := []struct { + name string + replayBufferSize int + // lastN is the number of values to return from the replay buffer + lastN int + expectedValues []int + }{ + { + name: "n < replayBufferSize", + replayBufferSize: 5, + lastN: 3, + // the replay buffer is not full so Last should return values + // starting from the first published value. + expectedValues: values[:3], // []int{1, 2, 3}, + }, + { + name: "n = replayBufferSize", + replayBufferSize: 5, + lastN: 5, + expectedValues: values, + }, + { + name: "n > replayBufferSize", + replayBufferSize: 3, + lastN: 5, + // the replay buffer is full so Last should return values starting + // from lastN - replayBufferSize. + expectedValues: values[2:], // []int{3, 4, 5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ctx = context.Background() + + replayObsvbl, publishCh := + channel.NewReplayObservable[int](ctx, tt.replayBufferSize) + + for _, value := range values { + publishCh <- value + time.Sleep(time.Millisecond) + } + + actualValues := replayObsvbl.Last(ctx, tt.lastN) + require.ElementsMatch(t, tt.expectedValues, actualValues) + }) + } +} + +func TestReplayObservable_Last_Blocks_And_Times_Out(t *testing.T) { + var ( + replayBufferSize = 5 + lastN = 5 + // splitIdx is the index at which this test splits the set of values. + // The two groups of values are published at different points in the + // test to test the behavior of Last under different conditions. + splitIdx = 3 + values = []int{1, 2, 3, 4, 5} + ctx = context.Background() + ) + + replayObsvbl, publishCh := channel.NewReplayObservable[int](ctx, replayBufferSize) + + // getLastValues is a helper function which returns a channel that will + // receive the result of a call to Last, the method under test. + getLastValues := func() chan []int { + lastValuesCh := make(chan []int, 1) + go func() { + // Last should block until lastN values have been published. + // NOTE: this will produce a warning log which can safely be ignored: + // > WARN: requested replay buffer size 3 is greater than replay buffer + // > capacity 3; returning entire replay buffer + lastValuesCh <- replayObsvbl.Last(ctx, lastN) + }() + return lastValuesCh + } + + // Ensure that last blocks when the replay buffer is empty + select { + case actualValues := <-getLastValues(): + t.Fatalf( + "Last should block until at lest 1 value has been published; actualValues: %v", + actualValues, + ) + case <-time.After(200 * time.Millisecond): + } + + // Publish some values (up to splitIdx). + for _, value := range values[:splitIdx] { + publishCh <- value + time.Sleep(time.Millisecond) + } + + // Ensure Last works as expected when n <= len(published_values). + require.ElementsMatch(t, []int{1}, replayObsvbl.Last(ctx, 1)) + require.ElementsMatch(t, []int{1, 2}, replayObsvbl.Last(ctx, 2)) + require.ElementsMatch(t, []int{1, 2, 3}, replayObsvbl.Last(ctx, 3)) + + // Ensure that Last blocks when n > len(published_values) and the replay + // buffer is not full. 
+ select { + case actualValues := <-getLastValues(): + t.Fatalf( + "Last should block until replayPartialBufferTimeout has elapsed; received values: %v", + actualValues, + ) + default: + t.Log("OK: Last is blocking, as expected") + } + + // Ensure that Last returns the correct values when n > len(published_values) + // and the replay buffer is not full. + select { + case actualValues := <-getLastValues(): + require.ElementsMatch(t, values[:splitIdx], actualValues) + case <-time.After(250 * time.Millisecond): + t.Fatal("timed out waiting for Last to return") + } + + // Publish the rest of the values (from splitIdx on). + for _, value := range values[splitIdx:] { + publishCh <- value + time.Sleep(time.Millisecond) + } + + // Ensure that Last doesn't block when n = len(published_values) and the + // replay buffer is full. + select { + case actualValues := <-getLastValues(): + require.Len(t, actualValues, lastN) + require.ElementsMatch(t, values, actualValues) + case <-time.After(10 * time.Millisecond): + t.Fatal("timed out waiting for Last to return") + } + + // Ensure that Last still works as expected when n <= len(published_values). + require.ElementsMatch(t, []int{1}, replayObsvbl.Last(ctx, 1)) + require.ElementsMatch(t, []int{1, 2}, replayObsvbl.Last(ctx, 2)) + require.ElementsMatch(t, []int{1, 2, 3}, replayObsvbl.Last(ctx, 3)) + require.ElementsMatch(t, []int{1, 2, 3, 4}, replayObsvbl.Last(ctx, 4)) + require.ElementsMatch(t, []int{1, 2, 3, 4, 5}, replayObsvbl.Last(ctx, 5)) +} diff --git a/pkg/observable/interface.go b/pkg/observable/interface.go index 452c18dcd..d86da414f 100644 --- a/pkg/observable/interface.go +++ b/pkg/observable/interface.go @@ -7,12 +7,18 @@ import "context" // grow, other packages (e.g. https://github.com/ReactiveX/RxGo) can be considered. // (see: https://github.com/ReactiveX/RxGo/pull/377) +// ReplayObservable is an observable which replays the last n values published +// to new observers, before publishing new values to observers. +type ReplayObservable[V any] interface { + Observable[V] + // Last synchronously returns the last n values from the replay buffer. + Last(ctx context.Context, n int) []V +} + // Observable is a generic interface that allows multiple subscribers to be // notified of new values asynchronously. // It is analogous to a publisher in a "Fan-Out" system design. type Observable[V any] interface { - // Next synchronously returns the next value from the observable. - Next(context.Context) V // Subscribe returns an observer which is notified when the publishCh channel // receives a value. Subscribe(context.Context) Observer[V]
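A short usage sketch to accompany the new ReplayObservable interface. It mirrors the scenarios in replay_test.go; the buffer size, published values, sleep duration, and printed output below are illustrative only:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"pocket/pkg/observable/channel"
)

func main() {
	ctx := context.Background()

	// Replay the last 3 values to every new observer.
	replayObsvbl, publishCh := channel.NewReplayObservable[int](ctx, 3)

	for _, v := range []int{1, 2, 3, 4, 5} {
		publishCh <- v
	}
	// Give the buffering goroutine a moment to observe the published values,
	// as the tests above do.
	time.Sleep(10 * time.Millisecond)

	// A late subscriber first receives the replayed values, oldest first.
	observer := replayObsvbl.Subscribe(ctx)
	fmt.Println(<-observer.Ch(), <-observer.Ch(), <-observer.Ch()) // expected: 3 4 5

	// Last blocks until at least one value is buffered, then returns up to n
	// values from the replay buffer.
	fmt.Println(replayObsvbl.Last(ctx, 2)) // expected: [3 4]
}
```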