diff --git a/go.mod b/go.mod index cb8df3174..7e42ae4a4 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 require ( cosmossdk.io/api v0.3.1 + cosmossdk.io/depinject v1.0.0-alpha.3 cosmossdk.io/errors v1.0.0-beta.7 cosmossdk.io/math v1.0.1 github.com/cometbft/cometbft v0.37.2 @@ -36,7 +37,6 @@ require ( cloud.google.com/go/iam v0.13.0 // indirect cloud.google.com/go/storage v1.29.0 // indirect cosmossdk.io/core v0.5.1 // indirect - cosmossdk.io/depinject v1.0.0-alpha.3 // indirect cosmossdk.io/log v1.1.1-0.20230704160919-88f2c830b0ca // indirect cosmossdk.io/tools/rosetta v0.2.1 // indirect filippo.io/edwards25519 v1.0.0 // indirect diff --git a/internal/testclient/testblock/client.go b/internal/testclient/testblock/client.go new file mode 100644 index 000000000..0d0f1b78b --- /dev/null +++ b/internal/testclient/testblock/client.go @@ -0,0 +1,27 @@ +package testblock + +import ( + "context" + "testing" + + "cosmossdk.io/depinject" + "github.com/stretchr/testify/require" + + "pocket/internal/testclient" + "pocket/internal/testclient/testeventsquery" + "pocket/pkg/client" + "pocket/pkg/client/block" +) + +func NewLocalnetClient(ctx context.Context, t *testing.T) client.BlockClient { + t.Helper() + + queryClient := testeventsquery.NewLocalnetClient(t) + require.NotNil(t, queryClient) + + deps := depinject.Supply(queryClient) + bClient, err := block.NewBlockClient(ctx, deps, testclient.CometLocalWebsocketURL) + require.NoError(t, err) + + return bClient +} diff --git a/pkg/client/block/block.go b/pkg/client/block/block.go new file mode 100644 index 000000000..5fe9a2e1e --- /dev/null +++ b/pkg/client/block/block.go @@ -0,0 +1,44 @@ +package block + +import ( + "encoding/json" + + "github.com/cometbft/cometbft/types" + + "pocket/pkg/client" +) + +// cometBlockEvent is used to deserialize incoming committed block event messages +// from the respective events query subscription. It implements the client.Block +// interface by loosely wrapping cometbft's block type, into which messages are +// deserialized. +type cometBlockEvent struct { + Block types.Block `json:"block"` +} + +// Height returns the block's height. +func (blockEvent *cometBlockEvent) Height() int64 { + return blockEvent.Block.Height +} + +// Hash returns the binary representation of the block's hash as a byte slice. +func (blockEvent *cometBlockEvent) Hash() []byte { + return blockEvent.Block.LastBlockID.Hash.Bytes() +} + +// newCometBlockEvent attempts to deserialize the given bytes into a comet block. +// if the resulting block has a height of zero, assume the event was not a block +// event and return an ErrUnmarshalBlockEvent error. +func newCometBlockEvent(blockMsgBz []byte) (client.Block, error) { + blockMsg := new(cometBlockEvent) + if err := json.Unmarshal(blockMsgBz, blockMsg); err != nil { + return nil, err + } + + // If msg does not match the expected format then the block's height has a zero value. + if blockMsg.Block.Header.Height == 0 { + return nil, ErrUnmarshalBlockEvent.Wrap(string(blockMsgBz)) + } + + return blockMsg, nil +} diff --git a/pkg/client/block/client.go b/pkg/client/block/client.go new file mode 100644 index 000000000..387f5f16b --- /dev/null +++ b/pkg/client/block/client.go @@ -0,0 +1,209 @@ +package block + +import ( + "context" + "fmt" + "time" + + "cosmossdk.io/depinject" + + "pocket/pkg/client" + "pocket/pkg/either" + "pocket/pkg/observable" + "pocket/pkg/observable/channel" + "pocket/pkg/retry" +) + +const ( + // eventsBytesRetryDelay is the delay between retry attempts when the events + // bytes observable returns an error. + eventsBytesRetryDelay = time.Second + // eventsBytesRetryLimit is the maximum number of times to attempt to + // re-establish the events query bytes subscription when the events bytes + // observable returns an error. + eventsBytesRetryLimit = 10 + eventsBytesRetryResetTimeout = 10 * time.Second + // NB: cometbft event subscription query for newly committed blocks. + // (see: https://docs.cosmos.network/v0.47/core/events#subscribing-to-events) + committedBlocksQuery = "tm.event='NewBlock'" + // latestBlockObsvblsReplayBufferSize is the replay buffer size of the + // latestBlockObsvbls replay observable which is used to cache the latest block observable. + // It is updated with a new "active" observable when a new + // events query subscription is created, for example, after a non-persistent + // connection error. + latestBlockObsvblsReplayBufferSize = 1 + // latestBlockReplayBufferSize is the replay buffer size of the latest block + // replay observable which is notified when block commit events are received + // by the events query client subscription created in goPublishBlocks. + latestBlockReplayBufferSize = 1 +) + +var ( + _ client.BlockClient = (*blockClient)(nil) + _ client.Block = (*cometBlockEvent)(nil) +) + +// blockClient implements the BlockClient interface. +type blockClient struct { + // endpointURL is the URL of RPC endpoint which eventsClient subscription + // requests will be sent. + endpointURL string + // eventsClient is the events query client which is used to subscribe to + // newly committed block events. It emits an either value which may contain + // an error, at most, once and closes immediately after if it does. + eventsClient client.EventsQueryClient + // latestBlockObsvbls is a replay observable with replay buffer size 1, + // which holds the "active latest block observable" which is notified when + // block commit events are received by the events query client subscription + // created in goPublishBlocks. This observable (and the one it emits) closes + // when the events bytes observable returns an error and is updated with a + // new "active" observable after a new events query subscription is created. + latestBlockObsvbls observable.ReplayObservable[client.BlocksObservable] + // latestBlockObsvblsReplayPublishCh is the publish channel for latestBlockObsvbls. + // It's used to set blockObsvbl initially and subsequently update it, for + // example, when the connection is re-established after erroring. + latestBlockObsvblsReplayPublishCh chan<- client.BlocksObservable +} + +// eventsBytesToBlockMapFn is a convenience type to represent the type of a +// function which maps event subscription message bytes into block event objects. +// This is used as a transformFn in a channel.Map() call and is the type returned +// by the newEventsBytesToBlockMapFn factory function. +type eventBytesToBlockMapFn func(either.Either[[]byte]) (client.Block, bool) + +// NewBlockClient creates a new block client from the given dependencies and cometWebsocketURL. +func NewBlockClient( + ctx context.Context, + deps depinject.Config, + cometWebsocketURL string, +) (client.BlockClient, error) { + // Initialize block client + bClient := &blockClient{endpointURL: cometWebsocketURL} + bClient.latestBlockObsvbls, bClient.latestBlockObsvblsReplayPublishCh = + channel.NewReplayObservable[client.BlocksObservable](ctx, latestBlockObsvblsReplayBufferSize) + + // Inject dependencies + if err := depinject.Inject(deps, &bClient.eventsClient); err != nil { + return nil, err + } + + // Concurrently publish blocks to the observable emitted by latestBlockObsvbls. + go bClient.goPublishBlocks(ctx) + + return bClient, nil +} + +// CommittedBlocksSequence returns a ReplayObservable, with a replay buffer size +// of 1, which is notified when block commit events are received by the events +// query subscription. +func (bClient *blockClient) CommittedBlocksSequence(ctx context.Context) client.BlocksObservable { + // Get the latest block observable from the replay observable. We only ever + // want the last 1 as any prior latest block observable values are closed. + // Directly accessing the zeroth index here is safe because the call to Last + // is guaranteed to return a slice with at least 1 element. + return bClient.latestBlockObsvbls.Last(ctx, 1)[0] +} + +// LatestBlock returns the latest committed block that's been received by the +// corresponding events query subscription. +// It blocks until at least one block event has been received. +func (bClient *blockClient) LatestBlock(ctx context.Context) client.Block { + return bClient.CommittedBlocksSequence(ctx).Last(ctx, 1)[0] +} + +// Close unsubscribes all observers of the committed blocks sequence observable +// and closes the events query client. +func (bClient *blockClient) Close() { + // Closing eventsClient will cascade unsubscribe and close downstream observers. + bClient.eventsClient.Close() +} + +// goPublishBlocks runs the work function returned by retryPublishBlocksFactory, +// re-invoking it according to the arguments to retry.OnError when the events bytes +// observable returns an asynchronous error. +// This function is intended to be called in a goroutine. +func (bClient *blockClient) goPublishBlocks(ctx context.Context) { + // React to errors by getting a new events bytes observable, re-mapping it, + // and send it to latestBlockObsvblsReplayPublishCh such that + // latestBlockObsvbls.Last(ctx, 1) will return it. + publishErr := retry.OnError( + ctx, + eventsBytesRetryLimit, + eventsBytesRetryDelay, + eventsBytesRetryResetTimeout, + "goPublishBlocks", + bClient.retryPublishBlocksFactory(ctx), + ) + + // If we get here, the retry limit was reached and the retry loop exited. + // Since this function runs in a goroutine, we can't return the error to the + // caller. Instead, we panic. + panic(fmt.Errorf("BlockClient.goPublishBlocks shold never reach this spot: %w", publishErr)) +} + +// retryPublishBlocksFactory returns a function which is intended to be passed to +// retry.OnError. The returned function pipes event bytes from the events query +// client, maps them to block events, and publishes them to the latestBlockObsvbls +// replay observable. +func (bClient *blockClient) retryPublishBlocksFactory(ctx context.Context) func() chan error { + return func() chan error { + errCh := make(chan error, 1) + eventsBzObsvbl, err := bClient.eventsClient.EventsBytes(ctx, committedBlocksQuery) + if err != nil { + errCh <- err + return errCh + } + + // NB: must cast back to generic observable type to use with Map. + // client.BlocksObservable is only used to workaround gomock's lack of + // support for generic types. + eventsBz := observable.Observable[either.Either[[]byte]](eventsBzObsvbl) + blockEventFromEventBz := newEventsBytesToBlockMapFn(errCh) + blocksObsvbl := channel.MapReplay(ctx, latestBlockReplayBufferSize, eventsBz, blockEventFromEventBz) + + // Initially set latestBlockObsvbls and update if after retrying on error. + bClient.latestBlockObsvblsReplayPublishCh <- blocksObsvbl + + return errCh + } +} + +// newEventsBytesToBlockMapFn is a factory for a function which is intended +// to be used as a transformFn in a channel.Map() call. Since the map function +// is called asynchronously, this factory creates a closure around an error channel +// which can be used for asynchronous error signaling from within the map function, +// and handling from the Map call context. +// +// The map function itself attempts to deserialize the given byte slice as a +// committed block event. If the events bytes observable contained an error, this value is not emitted +// (skipped) on the destination observable of the map operation. +// If deserialization failed because the event bytes were for a different event type, +// this value is also skipped. +// If deserialization failed for some other reason, this function panics. +func newEventsBytesToBlockMapFn(errCh chan<- error) eventBytesToBlockMapFn { + return func(eitherEventBz either.Either[[]byte]) (_ client.Block, skip bool) { + eventBz, err := eitherEventBz.ValueOrError() + if err != nil { + errCh <- err + // Don't publish (skip) if eitherEventBz contained an error. + // eitherEventBz should automatically close itself in this case. + // (i.e. no more values should be mapped to this transformFn's respective + // dstObservable). + return nil, true + } + + block, err := newCometBlockEvent(eventBz) + if err != nil { + if ErrUnmarshalBlockEvent.Is(err) { + // Don't publish (skip) if the message was not a block event. + return nil, true + } + + panic(fmt.Sprintf( + "unexpected error deserializing block event: %s; eventBz: %s", + err, string(eventBz), + )) + } + return block, false + } +} diff --git a/pkg/client/block/client_integration_test.go b/pkg/client/block/client_integration_test.go new file mode 100644 index 000000000..4f51d7873 --- /dev/null +++ b/pkg/client/block/client_integration_test.go @@ -0,0 +1,77 @@ +//go:build integration + +package block_test + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "pocket/internal/testclient/testblock" + "pocket/pkg/client" +) + +const blockIntegrationSubTimeout = 5 * time.Second + +func TestBlockClient_LatestBlock(t *testing.T) { + ctx := context.Background() + + blockClient := testblock.NewLocalnetClient(ctx, t) + require.NotNil(t, blockClient) + + block := blockClient.LatestBlock(ctx) + require.NotEmpty(t, block) +} + +func TestBlockClient_BlocksObservable(t *testing.T) { + ctx := context.Background() + + blockClient := testblock.NewLocalnetClient(ctx, t) + require.NotNil(t, blockClient) + + blockSub := blockClient.CommittedBlocksSequence(ctx).Subscribe(ctx) + + var ( + blockMu sync.Mutex + blockCounter int + blocksToRecv = 2 + errCh = make(chan error, 1) + ) + go func() { + var previousBlock client.Block + for block := range blockSub.Ch() { + if previousBlock != nil { + if !assert.Equal(t, previousBlock.Height()+1, block.Height()) { + errCh <- fmt.Errorf("expected block height %d, got %d", previousBlock.Height()+1, block.Height()) + return + } + } + previousBlock = block + + require.NotEmpty(t, block) + blockMu.Lock() + blockCounter++ + if blockCounter >= blocksToRecv { + errCh <- nil + return + } + blockMu.Unlock() + } + }() + + select { + case err := <-errCh: + require.NoError(t, err) + require.Equal(t, blocksToRecv, blockCounter) + case <-time.After(blockIntegrationSubTimeout): + t.Fatalf( + "timed out waiting for block subscription; expected %d blocks, got %d", + blocksToRecv, blockCounter, + ) + } +} diff --git a/pkg/client/block/client_test.go b/pkg/client/block/client_test.go new file mode 100644 index 000000000..c787f5ad2 --- /dev/null +++ b/pkg/client/block/client_test.go @@ -0,0 +1,138 @@ +package block_test + +import ( + "context" + "encoding/json" + "testing" + "time" + + "cosmossdk.io/depinject" + comettypes "github.com/cometbft/cometbft/types" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + "pocket/internal/testclient" + "pocket/internal/testclient/testeventsquery" + "pocket/pkg/client" + "pocket/pkg/client/block" + eventsquery "pocket/pkg/client/events_query" +) + +const blockAssertionLoopTimeout = 500 * time.Millisecond + +func TestBlockClient(t *testing.T) { + var ( + expectedHeight = int64(1) + expectedHash = []byte("test_hash") + expectedBlockEvent = &testBlockEvent{ + Block: comettypes.Block{ + Header: comettypes.Header{ + Height: 1, + Time: time.Now(), + LastBlockID: comettypes.BlockID{ + Hash: expectedHash, + }, + }, + }, + } + ctx = context.Background() + ) + + // Set up a mock connection and dialer which are expected to be used once. + connMock, dialerMock := testeventsquery.NewOneTimeMockConnAndDialer(t) + connMock.EXPECT().Send(gomock.Any()).Return(nil).Times(1) + // Mock the Receive method to return the expected block event. + connMock.EXPECT().Receive().DoAndReturn(func() ([]byte, error) { + blockEventJson, err := json.Marshal(expectedBlockEvent) + require.NoError(t, err) + return blockEventJson, nil + }).AnyTimes() + + // Set up events query client dependency. + dialerOpt := eventsquery.WithDialer(dialerMock) + eventsQueryClient := testeventsquery.NewLocalnetClient(t, dialerOpt) + deps := depinject.Supply(eventsQueryClient) + + // Set up block client. + blockClient, err := block.NewBlockClient(ctx, deps, testclient.CometLocalWebsocketURL) + require.NoError(t, err) + require.NotNil(t, blockClient) + + // Run LatestBlock and CommittedBlockSequence concurrently because they can + // block, leading to an unresponsive test. This function sends multiple values + // on the actualBlockCh which are all asserted against in blockAssertionLoop. + // If any of the methods under test hang, the test will time out. + var ( + actualBlockCh = make(chan client.Block, 1) + done = make(chan struct{}, 1) + ) + go func() { + // Test LatestBlock method. + actualBlock := blockClient.LatestBlock(ctx) + require.Equal(t, expectedHeight, actualBlock.Height()) + require.Equal(t, expectedHash, actualBlock.Hash()) + + // Test CommittedBlockSequence method. + blockObservable := blockClient.CommittedBlocksSequence(ctx) + require.NotNil(t, blockObservable) + + // Ensure that the observable is replayable via Last. + actualBlockCh <- blockObservable.Last(ctx, 1)[0] + + // Ensure that the observable is replayable via Subscribe. + blockObserver := blockObservable.Subscribe(ctx) + for block := range blockObserver.Ch() { + actualBlockCh <- block + break + } + + // Signal test completion + done <- struct{}{} + }() + + // blockAssertionLoop ensures that the blocks retrieved from both LatestBlock + // method and CommittedBlocksSequence method match the expected block height + // and hash. This loop waits for blocks to be sent on the actualBlockCh channel + // by the methods being tested. Once the methods are done, they send a signal on + // the "done" channel. If the blockAssertionLoop doesn't receive any block or + // the done signal within a specific timeout, it assumes something has gone wrong + // and fails the test. +blockAssertionLoop: + for { + select { + case actualBlock := <-actualBlockCh: + require.Equal(t, expectedHeight, actualBlock.Height()) + require.Equal(t, expectedHash, actualBlock.Hash()) + case <-done: + break blockAssertionLoop + case <-time.After(blockAssertionLoopTimeout): + t.Fatal("timed out waiting for block event") + } + } + + // Wait a tick for the observables to be set up. + time.Sleep(time.Millisecond) + + blockClient.Close() +} + +/* +TODO_TECHDEBT/TODO_CONSIDERATION(#XXX): this duplicates the unexported block event + +type from pkg/client/block/block.go. We seem to have some conflicting preferences +which result in the need for this duplication until a preferred direction is +identified: + + - We should prefer tests being in their own pkgs (e.g. block_test) + - this would resolve if this test were in the block package instead. + - We should prefer to not export types which don't require exporting for API + consumption. + - This test is the only external (to the block pkg) dependency of cometBlockEvent. + - We could use the //go:build test constraint on a new file which exports it + for testing purposes. + - This would imply that we also add -tags=test to all applicable tooling + and add a test which fails if the tag is absent. +*/ +type testBlockEvent struct { + Block comettypes.Block `json:"block"` +} diff --git a/pkg/client/block/errors.go b/pkg/client/block/errors.go new file mode 100644 index 000000000..0a0cc28c9 --- /dev/null +++ b/pkg/client/block/errors.go @@ -0,0 +1,8 @@ +package block + +import errorsmod "cosmossdk.io/errors" + +var ( + ErrUnmarshalBlockEvent = errorsmod.Register(codespace, 1, "failed to unmarshal committed block event") + codespace = "block_client" +) diff --git a/pkg/client/interface.go b/pkg/client/interface.go index bd811a153..1007656a5 100644 --- a/pkg/client/interface.go +++ b/pkg/client/interface.go @@ -1,4 +1,4 @@ -//go:generate mockgen -destination=../../internal/mocks/mockclient/query_client_mock.go -package=mockclient . Dialer,Connection +//go:generate mockgen -destination=../../internal/mocks/mockclient/events_query_client_mock.go -package=mockclient . Dialer,Connection,EventsQueryClient package client @@ -9,6 +9,28 @@ import ( "pocket/pkg/observable" ) +// BlocksObservable is an observable which is notified with an either +// value which contains either an error or the event message bytes. +// TODO_HACK: The purpose of this type is to work around gomock's lack of +// support for generic types. For the same reason, this type cannot be an +// alias (i.e. EventsBytesObservable = observable.Observable[either.Either[[]byte]]). +type BlocksObservable observable.ReplayObservable[Block] + +type BlockClient interface { + // Blocks returns an observable which emits newly committed blocks. + CommittedBlocksSequence(context.Context) BlocksObservable + // LatestBlock returns the latest block that has been committed. + LatestBlock(context.Context) Block + // Close unsubscribes all observers of the committed blocks sequence observable + // and closes the events query client. + Close() +} + +type Block interface { + Height() int64 + Hash() []byte +} + // TODO_CONSIDERATION: the cosmos-sdk CLI code seems to use a cometbft RPC client // which includes a `#Subscribe()` method for a similar purpose. Perhaps we could // replace this custom websocket client with that. diff --git a/pkg/observable/channel/map.go b/pkg/observable/channel/map.go index 942859e50..912043ca9 100644 --- a/pkg/observable/channel/map.go +++ b/pkg/observable/channel/map.go @@ -6,6 +6,8 @@ import ( "pocket/pkg/observable" ) +type MapFn[S, D any] func(src S) (dst D, skip bool) + // Map transforms the given observable by applying the given transformFn to each // notification received from the observable. If the transformFn returns a skip // bool of true, the notification is skipped and not emitted to the resulting @@ -14,7 +16,7 @@ func Map[S, D any]( ctx context.Context, srcObservable observable.Observable[S], // TODO_CONSIDERATION: if this were variadic, it could simplify serial transformations. - transformFn func(src S) (dst D, skip bool), + transformFn MapFn[S, D], ) observable.Observable[D] { dstObservable, dstProducer := NewObservable[D]() srcObserver := srcObservable.Subscribe(ctx) @@ -32,3 +34,33 @@ func Map[S, D any]( return dstObservable } + +// MapReplay transforms the given observable by applying the given transformFn to +// each notification received from the observable. If the transformFn returns a +// skip bool of true, the notification is skipped and not emitted to the resulting +// observable. +// The resulting observable will receive the last replayBufferSize +// number of values published to the source observable before receiving new values. +func MapReplay[S, D any]( + ctx context.Context, + replayBufferSize int, + srcObservable observable.Observable[S], + // TODO_CONSIDERATION: if this were variadic, it could simplify serial transformations. + transformFn func(src S) (dst D, skip bool), +) observable.ReplayObservable[D] { + dstObservable, dstProducer := NewReplayObservable[D](ctx, replayBufferSize) + srcObserver := srcObservable.Subscribe(ctx) + + go func() { + for srcNotification := range srcObserver.Ch() { + dstNotification, skip := transformFn(srcNotification) + if skip { + continue + } + + dstProducer <- dstNotification + } + }() + + return dstObservable +} diff --git a/pkg/retry/retry.go b/pkg/retry/retry.go new file mode 100644 index 000000000..f76a7dcf5 --- /dev/null +++ b/pkg/retry/retry.go @@ -0,0 +1,72 @@ +package retry + +import ( + "context" + "log" + "time" +) + +type RetryFunc func() chan error + +// OnError continuously invokes the provided work function (workFn) until either the context (ctx) +// is canceled or the error channel returned by workFn is closed. If workFn encounters an error, +// OnError will retry invoking workFn based on the provided retry parameters. +// +// Parameters: +// - ctx: the context to monitor for cancellation. If canceled, OnError will exit without error. +// - retryLimit: the maximum number of retries for workFn upon encountering an error. +// - retryDelay: the duration to wait before retrying workFn after an error. +// - retryResetCount: Specifies the duration of continuous error-free operation required +// before the retry count is reset. If the work function operates without +// errors for this duration, any subsequent error will restart the retry +// count from the beginning. +// - workName: a name or descriptor for the work function, used for logging purposes. +// - workFn: a function that performs some work and returns an error channel. +// This channel emits errors encountered during the work. +// +// Returns: +// - If the context is canceled, the function returns nil. +// - If the error channel is closed, a warning is logged, and the function returns nil. +// - If the retry limit is reached, the function returns the error from the channel. +// +// Note: After each error, a delay specified by retryDelay is introduced before retrying workFn.func OnError( +func OnError( + ctx context.Context, + retryLimit int, + retryDelay time.Duration, + retryResetTimeout time.Duration, + workName string, + workFn RetryFunc, +) error { + var retryCount int + errCh := workFn() + for { + select { + case <-ctx.Done(): + return nil + case <-time.After(retryResetTimeout): + retryCount = 0 + case err, ok := <-errCh: + // Exit the retry loop if the error channel is closed. + if !ok { + log.Printf( + "WARN: error channel for %s closed, will no longer retry on error\n", + workName, + ) + return nil + } + + if retryCount >= retryLimit { + return err + } + + // Wait retryDelay before retrying. + time.Sleep(retryDelay) + + // Increment retryCount and retry workFn. + retryCount++ + errCh = workFn() + log.Printf("ERROR: retrying %s after error: %s\n", workName, err) + } + } +} diff --git a/pkg/retry/retry_test.go b/pkg/retry/retry_test.go new file mode 100644 index 000000000..8a1154c30 --- /dev/null +++ b/pkg/retry/retry_test.go @@ -0,0 +1,337 @@ +package retry_test + +/* TODO_TECHDEBT: improve this test: +- fix race condition around the logOutput buffer +- factor our common setup and assertion code +- drive out flakiness +- improve comments +*/ + +import ( + "bytes" + "context" + "fmt" + "log" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "pocket/pkg/retry" +) + +var testErr = fmt.Errorf("test error") + +// TestOnError verifies the behavior of the OnError function in the retry package. +// It ensures that the function correctly retries a failing operation for a specified +// number of times with the expected delay between retries. +func TestOnError(t *testing.T) { + t.Skip("TODO_TECHDEBT: this test should pass but contains a race condition around the logOutput buffer") + + // Setting up the test variables. + var ( + // logOutput captures the log output for verification of logged messages. + logOutput bytes.Buffer + // expectedRetryDelay is the duration we expect between retries. + expectedRetryDelay = time.Millisecond + // expectedRetryLimit is the maximum number of retries the test expects. + expectedRetryLimit = 5 + // retryResetTimeout is the duration after which the retry count should reset. + retryResetTimeout = time.Second + // testFnCallCount keeps track of how many times the test function is called. + testFnCallCount int32 + // testFnCallTimeCh is a channel receives a time.Time each when the test + // function is called. + testFnCallTimeCh = make(chan time.Time, expectedRetryLimit) + ctx = context.Background() + ) + + // Redirect the standard logger's output to our custom buffer for later verification. + log.SetOutput(&logOutput) + + // Define testFn, a function that simulates a failing operation and logs its invocation times. + testFn := func() chan error { + // Record the current time to track the delay between retries. + testFnCallTimeCh <- time.Now() + + // Create a channel to return an error, simulating a failing operation. + errCh := make(chan error, 1) + errCh <- testErr + + // Increment the call count safely across goroutine boundaries. + atomic.AddInt32(&testFnCallCount, 1) + + return errCh + } + + // Create a channel to receive the error result from the OnError function. + retryOnErrorErrCh := make(chan error, 1) + + // Start the OnError function in a separate goroutine, simulating concurrent operation. + go func() { + // Call the OnError function with the test parameters and function. + retryOnErrorErrCh <- retry.OnError( + ctx, + expectedRetryLimit, + expectedRetryDelay, + retryResetTimeout, + "TestOnError", + testFn, + ) + }() + + // Calculate the total expected time for all retries to complete. + totalExpectedDelay := expectedRetryDelay * time.Duration(expectedRetryLimit) + // Wait for the OnError function to execute and retry the expected number of times. + time.Sleep(totalExpectedDelay + 100*time.Millisecond) + + // Verify that the test function was called the expected number of times. + require.Equal(t, expectedRetryLimit, int(testFnCallCount), "Test function was not called the expected number of times") + + // Verify the delay between retries of the test function. + var prevCallTime time.Time + for i := 0; i < expectedRetryLimit; i++ { + // Retrieve the next function call time from the channel. + nextCallTime, ok := <-testFnCallTimeCh + if !ok { + t.Fatalf("expected %d calls to testFn, but channel closed after %d", expectedRetryLimit, i) + } + + // For all calls after the first, check that the delay since the previous call meets expectations. + if i != 0 { + actualRetryDelay := nextCallTime.Sub(prevCallTime) + require.GreaterOrEqual(t, actualRetryDelay, expectedRetryDelay, "Retry delay was less than expected") + } + + // Update prevCallTime for the next iteration. + prevCallTime = nextCallTime + } + + // Verify that the OnError function returned the expected error. + select { + case err := <-retryOnErrorErrCh: + require.ErrorIs(t, err, testErr, "OnError did not return the expected error") + case <-time.After(100 * time.Millisecond): + t.Fatal("expected error from OnError, but none received") + } + + // Verify the error messages logged during the retries. + expectedErrorLine := "ERROR: retrying TestOnError after error: test error" + trimmedLogOutput := strings.Trim(logOutput.String(), "\n") + logOutputLines := strings.Split(trimmedLogOutput, "\n") + require.Lenf(t, logOutputLines, expectedRetryLimit, "unexpected number of log lines") + for _, line := range logOutputLines { + require.Contains(t, line, expectedErrorLine, "log line does not contain the expected prefix") + } +} + +// TODO_TECHDEBT: assert that the retry loop exits when the context is closed. +func TestOnError_ExitsWhenCtxCloses(t *testing.T) { + t.SkipNow() +} + +func TestOnError_ExitsWhenErrChCloses(t *testing.T) { + t.Skip("TODO_TECHDEBT: this test should pass but contains a race condition around the logOutput buffer") + + // Setup test variables and log capture + var ( + logOutput bytes.Buffer + testFnCallCount int32 + expectedRetryDelay = time.Millisecond + expectedRetryLimit = 3 + retryLimit = 5 + retryResetTimeout = time.Second + testFnCallTimeCh = make(chan time.Time, expectedRetryLimit) + ctx = context.Background() + ) + + // Redirect the log output for verification later + log.SetOutput(&logOutput) + + // Define the test function that simulates an error and counts its invocations + testFn := func() chan error { + atomic.AddInt32(&testFnCallCount, 1) // Increment the invocation count atomically + testFnCallTimeCh <- time.Now() // Track the invocation time + + errCh := make(chan error, 1) + if atomic.LoadInt32(&testFnCallCount) >= int32(expectedRetryLimit) { + close(errCh) + return errCh + } + + errCh <- testErr + return errCh + } + + retryOnErrorErrCh := make(chan error, 1) + // Spawn a goroutine to test the OnError function + go func() { + retryOnErrorErrCh <- retry.OnError( + ctx, + retryLimit, + expectedRetryDelay, + retryResetTimeout, + "TestOnError_ExitsWhenErrChCloses", + testFn, + ) + }() + + // Wait for the OnError function to execute and retry the expected number of times + totalExpectedDelay := expectedRetryDelay * time.Duration(expectedRetryLimit) + time.Sleep(totalExpectedDelay + 100*time.Millisecond) + + // Assert that the test function was called the expected number of times + require.Equal(t, expectedRetryLimit, int(testFnCallCount)) + + // Assert that the retry delay between function calls matches the expected delay + var prevCallTime = new(time.Time) + for i := 0; i < expectedRetryLimit; i++ { + select { + case nextCallTime := <-testFnCallTimeCh: + if i != 0 { + actualRetryDelay := nextCallTime.Sub(*prevCallTime) + require.GreaterOrEqual(t, actualRetryDelay, expectedRetryDelay) + } + + *prevCallTime = nextCallTime + default: + t.Fatalf( + "expected %d calls to testFn, but only received %d", + expectedRetryLimit, i+1, + ) + } + } + + select { + case err := <-retryOnErrorErrCh: + require.NoError(t, err) + case <-time.After(100 * time.Millisecond): + t.Fatalf("expected error from OnError, but none received") + } + + // Verify the logged error messages + var ( + logOutputLines = strings.Split(strings.Trim(logOutput.String(), "\n"), "\n") + errorLines = logOutputLines[:len(logOutputLines)-1] + warnLine = logOutputLines[len(logOutputLines)-1] + expectedWarnMsg = "WARN: error channel for TestOnError_ExitsWhenErrChCloses closed, will no longer retry on error" + expectedErrorMsg = "ERROR: retrying TestOnError_ExitsWhenErrChCloses after error: test error" + ) + + require.Lenf( + t, logOutputLines, + expectedRetryLimit, + "expected %d log lines, got %d", + expectedRetryLimit, len(logOutputLines), + ) + for _, line := range errorLines { + require.Contains(t, line, expectedErrorMsg) + } + require.Contains(t, warnLine, expectedWarnMsg) +} + +// assert that retryCount resets on success +func TestOnError_RetryCountResetTimeout(t *testing.T) { + t.Skip("TODO_TECHDEBT: this test should pass but contains a race condition around the logOutput buffer") + + // Setup test variables and log capture + var ( + logOutput bytes.Buffer + testFnCallCount int32 + expectedRetryDelay = time.Millisecond + expectedRetryLimit = 9 + retryLimit = 5 + retryResetTimeout = 3 * time.Millisecond + testFnCallTimeCh = make(chan time.Time, expectedRetryLimit) + ctx = context.Background() + ) + + // Redirect the log output for verification later + log.SetOutput(&logOutput) + + // Define the test function that simulates an error and counts its invocations + testFn := func() chan error { + // Track the invocation time + testFnCallTimeCh <- time.Now() + + errCh := make(chan error, 1) + + count := atomic.LoadInt32(&testFnCallCount) + if count == int32(retryLimit) { + go func() { + time.Sleep(retryResetTimeout) + errCh <- testErr + }() + } else { + errCh <- testErr + } + + // Increment the invocation count atomically + atomic.AddInt32(&testFnCallCount, 1) + return errCh + } + + retryOnErrorErrCh := make(chan error, 1) + // Spawn a goroutine to test the OnError function + go func() { + retryOnErrorErrCh <- retry.OnError( + ctx, + retryLimit, + expectedRetryDelay, + retryResetTimeout, + "TestOnError", + testFn, + ) + }() + + // Wait for the OnError function to execute and retry the expected number of times + totalExpectedDelay := expectedRetryDelay * time.Duration(expectedRetryLimit) + time.Sleep(totalExpectedDelay + 100*time.Millisecond) + + // Assert that the test function was called the expected number of times + require.Equal(t, expectedRetryLimit, int(testFnCallCount)) + + // Assert that the retry delay between function calls matches the expected delay + var prevCallTime = new(time.Time) + for i := 0; i < expectedRetryLimit; i++ { + select { + case nextCallTime := <-testFnCallTimeCh: + if i != 0 { + actualRetryDelay := nextCallTime.Sub(*prevCallTime) + require.GreaterOrEqual(t, actualRetryDelay, expectedRetryDelay) + } + + *prevCallTime = nextCallTime + default: + t.Fatalf( + "expected %d calls to testFn, but only received %d", + expectedRetryLimit, i+1, + ) + } + } + + // Verify the logged error messages + var ( + logOutputLines = strings.Split(strings.Trim(logOutput.String(), "\n"), "\n") + expectedPrefix = "ERROR: retrying TestOnError after error: test error" + ) + + select { + case err := <-retryOnErrorErrCh: + require.ErrorIs(t, err, testErr) + case <-time.After(100 * time.Millisecond): + t.Fatalf("expected error from OnError, but none received") + } + + require.Lenf( + t, logOutputLines, + expectedRetryLimit-1, + "expected %d log lines, got %d", + expectedRetryLimit-1, len(logOutputLines), + ) + for _, line := range logOutputLines { + require.Contains(t, line, expectedPrefix) + } +}