From e06081d10f3c7b63e24fe6a3dbe46836ba17b46d Mon Sep 17 00:00:00 2001 From: Hyeonho Kim Date: Thu, 14 Jul 2022 10:43:17 +0900 Subject: [PATCH] feat: add manual commit feature to kinesumer (#12) * feat(api): add commit feature to kinesumer * docs: add comment for error message * feat: update multiple checkpoints * feat: add kinesumer to mark record & manual commit * feat: divide commit & consuming, add offset config * refactor: change code style to origin * style: change offsetmanagement to commit properties * fix: hanging if error channel is full (#11) (#13) * docs: add caution message for error channel * style: change grammar issue * fix: error can be ignored if channel is blocked * style: change comments sendErrorOrIgnore * style: change sendErrorOrIgnore to sendOrDiscardError, remove docs * refactor: remove updating checkpoints with 25 items * test: add unit tests for markrecord&commit * style: make comments more clear * test: add state store test * refactor: rename commitCheckPointPerStream to commitCheckPointsPerStream * refactor: follow ok idiom * refactor: remove send error when empty commit checkpoints * doc: remove meaningless comments * refactor: remove duplicated consumed nil records check * refactor: rename commit properties to commit config * test: remove inappropriate test for commitCheckPointsPerStream * refactor: remove cleanup function and move to after end of consuming * refactor: make clean up offset using cleanUpOffset function * test: rename works properly to works fine * refactor: rename default commit properties to commit config --- go.mod | 2 + go.sum | 14 ++ kinesumer.go | 329 +++++++++++++++++++++++++++------------ kinesumer_test.go | 376 ++++++++++++++++++++++++++++++++++++++++++++- statemodel.go | 8 + statestore.go | 38 +++-- statestore_mock.go | 164 ++++++++++++++++++++ statestore_test.go | 108 +++++++++++++ syncclient.go | 3 + 9 files changed, 924 insertions(+), 118 deletions(-) create mode 100644 statestore_mock.go create mode 100644 statestore_test.go diff --git a/go.mod b/go.mod index 8907781..66c7357 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.16 require ( github.com/aws/aws-sdk-go v1.40.4 + github.com/golang/mock v1.6.0 github.com/guregu/dynamo v1.10.4 github.com/pkg/errors v0.9.1 github.com/rs/zerolog v1.23.0 + github.com/stretchr/testify v1.7.0 ) diff --git a/go.sum b/go.sum index 1bb9e29..a33bff0 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/guregu/dynamo v1.10.4 h1:okxTx3ibVXSO02tGEVDpe0x8oGvwwZnJ+tePtKTlpz0= github.com/guregu/dynamo v1.10.4/go.mod h1:h8dDh87mKIRfkSId4Qdk3PjAsNtRrldrNw72B/lHW0s= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -23,25 +25,33 @@ github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.23.0 h1:UskrK+saS9P9Y789yNNulYKdARjPZuS35B8gJF2x60g= github.com/rs/zerolog v1.23.0/go.mod h1:6c7hFfxPOy7TacJc4Fcdi24/J0NKYGzjG8FWRI916Qo= github.com/stretchr/objx v0.1.0/go.mod 
h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210614182718-04defd469f4e h1:XpT3nA5TvE525Ne3hInMh6+GETgn27Zfm9dxsThnX2Q= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -50,9 +60,13 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/tools v0.1.1/go.mod 
h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/kinesumer.go b/kinesumer.go index abc1c64..4f948dd 100644 --- a/kinesumer.go +++ b/kinesumer.go @@ -23,8 +23,8 @@ const ( syncInterval = 5*time.Second + jitter syncTimeout = 5*time.Second - jitter - checkPointTimeout = 2 * time.Second - checkPointInterval = 5 * time.Second // For EFO mode. + defaultCommitTimeout = 2 * time.Second + defaultCommitInterval = 5 * time.Second defaultScanLimit int64 = 2000 @@ -34,6 +34,14 @@ const ( recordsChanBuffer = 20 ) +// Error codes. +var ( + ErrEmptySequenceNumber = errors.New("kinesumer: sequence number can't be empty") + ErrInvalidStream = errors.New("kinesumer: invalid stream") + errEmptyCommitCheckpoints = errors.New("kinesumer: commit checkpoints can't be empty") + errMarkNilRecord = errors.New("kinesumer: nil record can't be marked") +) + // Config defines configs for the Kinesumer client. type Config struct { App string // Application name. @@ -43,7 +51,6 @@ type Config struct { // Kinesis configs. KinesisRegion string KinesisEndpoint string // Only for local server. - EFOMode bool // On/off the Enhanced Fan-Out feature. // If you want to consume messages from Kinesis in a different account, // you need to set up the IAM role to access to target account, and pass the role arn here. // Reference: https://docs.aws.amazon.com/kinesisanalytics/latest/java/examples-cross.html. @@ -59,11 +66,38 @@ type Config struct { ScanLimit int64 ScanTimeout time.Duration ScanInterval time.Duration + + EFOMode bool // On/off the Enhanced Fan-Out feature. + + // This config is used for how to manage sequence number. + Commit *CommitConfig +} + +// CommitConfig holds options for how to offset handled. +type CommitConfig struct { + // Whether to auto-commit updated sequence number. (default is true) + Auto bool + + // How frequently to commit updated sequence numbers. (default is 5s) + Interval time.Duration + + // A Timeout config for commit per stream. (default is 2s) + Timeout time.Duration +} + +// NewDefaultCommitConfig returns a new default offset management configuration. +func NewDefaultCommitConfig() *CommitConfig { + return &CommitConfig{ + Auto: true, + Interval: defaultCommitInterval, + Timeout: defaultCommitTimeout, + } } // Record represents kinesis.Record with stream name. type Record struct { - Stream string + Stream string + ShardID string *kinesis.Record } @@ -123,6 +157,8 @@ type Kinesumer struct { shards map[string]Shards // To cache the last sequence numbers for each shard. checkPoints map[string]*sync.Map + // offsets holds uncommitted sequence numbers. 
+ offsets map[string]*sync.Map // To manage the next shard iterators for each shard. nextIters map[string]*sync.Map @@ -135,6 +171,11 @@ type Kinesumer struct { started chan struct{} + // commit options. + autoCommit bool + commitTimeout time.Duration + commitInterval time.Duration + // To wait the running consumer loops when stopping. wait sync.WaitGroup stop chan struct{} @@ -196,6 +237,10 @@ func NewKinesumer(cfg *Config) (*Kinesumer, error) { ) } + if cfg.Commit == nil { + cfg.Commit = NewDefaultCommitConfig() + } + buffer := recordsChanBuffer kinesumer := &Kinesumer{ id: id, @@ -209,6 +254,7 @@ func NewKinesumer(cfg *Config) (*Kinesumer, error) { shardCaches: make(map[string][]string), shards: make(map[string]Shards), checkPoints: make(map[string]*sync.Map), + offsets: make(map[string]*sync.Map), nextIters: make(map[string]*sync.Map), scanLimit: defaultScanLimit, scanTimeout: defaultScanTimeout, @@ -234,6 +280,10 @@ func NewKinesumer(cfg *Config) (*Kinesumer, error) { kinesumer.efoMeta = make(map[string]*efoMeta) } + kinesumer.autoCommit = cfg.Commit.Auto + kinesumer.commitInterval = cfg.Commit.Interval + kinesumer.commitTimeout = cfg.Commit.Timeout + if err := kinesumer.init(); err != nil { return nil, errors.WithStack(err) } @@ -422,6 +472,9 @@ func (k *Kinesumer) start() { k.consumePolling() } + if k.autoCommit { + go k.commitPeriodically() + } } func (k *Kinesumer) pause() { @@ -450,65 +503,14 @@ func (k *Kinesumer) consumePipe(stream string, shard *Shard) { streamEvents := make(chan kinesis.SubscribeToShardEventStreamEvent) - go func() { - defer close(streamEvents) - - for { - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - - input := &kinesis.SubscribeToShardInput{ - ConsumerARN: aws.String(k.efoMeta[stream].consumerARN), - ShardId: aws.String(shard.ID), - StartingPosition: &kinesis.StartingPosition{ - Type: aws.String(kinesis.ShardIteratorTypeLatest), - }, - } - - if seq, ok := k.checkPoints[stream].Load(shard.ID); ok { - input.StartingPosition.SetType(kinesis.ShardIteratorTypeAfterSequenceNumber) - input.StartingPosition.SetSequenceNumber(seq.(string)) - } - - output, err := k.client.SubscribeToShardWithContext(ctx, input) - if err != nil { - k.sendOrDiscardError(errors.WithStack(err)) - cancel() - continue - } - - open := true - for open { - select { - case <-k.stop: - output.GetEventStream().Close() - cancel() - return - case <-k.close: - output.GetEventStream().Close() - cancel() - return - case e, ok := <-output.GetEventStream().Events(): - if !ok { - cancel() - open = false - } - streamEvents <- e - } - } - } - }() - - var ( - lastSequence string - checkPointTicker = time.NewTicker(checkPointInterval) - ) + go k.subscribeToShard(streamEvents, stream, shard) for { select { case e, ok := <-streamEvents: if !ok { - k.commitCheckPoint(stream, shard.ID, lastSequence) + k.Commit() + k.cleanupOffsets(stream, shard) return } if se, ok := e.(*kinesis.SubscribeToShardEvent); ok { @@ -518,40 +520,69 @@ func (k *Kinesumer) consumePipe(stream string, shard *Shard) { } for i, record := range se.Records { - k.records <- &Record{ - Stream: stream, - Record: record, + r := &Record{ + Stream: stream, + ShardID: shard.ID, + Record: record, } - if i == n-1 { - lastSequence = *record.SequenceNumber + k.records <- r + + if k.autoCommit && i == n-1 { + k.MarkRecord(r) } } } - case <-checkPointTicker.C: - k.commitCheckPoint(stream, shard.ID, lastSequence) } } } -// update checkpoint using sequence number. 
-func (k *Kinesumer) commitCheckPoint(stream, shardID, lastSeqNum string) { - if lastSeqNum == "" { - // sequence number can't be empty. - return - } +func (k *Kinesumer) subscribeToShard(streamEvents chan kinesis.SubscribeToShardEventStreamEvent, stream string, shard *Shard) { + defer close(streamEvents) - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, checkPointTimeout) - defer cancel() + for { + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + + input := &kinesis.SubscribeToShardInput{ + ConsumerARN: aws.String(k.efoMeta[stream].consumerARN), + ShardId: aws.String(shard.ID), + StartingPosition: &kinesis.StartingPosition{ + Type: aws.String(kinesis.ShardIteratorTypeLatest), + }, + } + + if seq, ok := k.checkPoints[stream].Load(shard.ID); ok { + input.StartingPosition.SetType(kinesis.ShardIteratorTypeAfterSequenceNumber) + input.StartingPosition.SetSequenceNumber(seq.(string)) + } + + output, err := k.client.SubscribeToShardWithContext(ctx, input) + if err != nil { + k.sendOrDiscardError(errors.WithStack(err)) + cancel() + continue + } - if err := k.stateStore.UpdateCheckPoint(ctx, stream, shardID, lastSeqNum); err != nil { - log.Err(err). - Str("stream", stream). - Str("shard id", shardID). - Str("missed sequence number", lastSeqNum). - Msg("kinesumer: failed to UpdateCheckPoint") + open := true + for open { + select { + case <-k.stop: + output.GetEventStream().Close() + cancel() + return + case <-k.close: + output.GetEventStream().Close() + cancel() + return + case e, ok := <-output.GetEventStream().Events(): + if !ok { + cancel() + open = false + } + streamEvents <- e + } + } } - k.checkPoints[stream].Store(shardID, lastSeqNum) } /* @@ -575,20 +606,42 @@ func (k *Kinesumer) consumeLoop(stream string, shard *Shard) { for { select { case <-k.stop: + k.Commit() return case <-k.close: + k.Commit() return default: time.Sleep(k.scanInterval) - if closed := k.consumeOnce(stream, shard); closed { + records, closed := k.consumeOnce(stream, shard) + if closed { + k.cleanupOffsets(stream, shard) return // Close consume loop if shard is CLOSED and has no data. } + + n := len(records) + if n == 0 { + continue + } + + for i, record := range records { + r := &Record{ + Stream: stream, + ShardID: shard.ID, + Record: record, + } + k.records <- r + + if k.autoCommit && i == n-1 { + k.MarkRecord(r) + } + } } } } -// It returns a flag whether if shard is CLOSED state and has no remaining data. -func (k *Kinesumer) consumeOnce(stream string, shard *Shard) bool { +// It returns records & flag which is whether if shard is CLOSED state and has no remaining data. +func (k *Kinesumer) consumeOnce(stream string, shard *Shard) ([]*kinesis.Record, bool) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, k.scanTimeout) defer cancel() @@ -598,7 +651,7 @@ func (k *Kinesumer) consumeOnce(stream string, shard *Shard) bool { k.sendOrDiscardError(errors.WithStack(err)) var riue *kinesis.ResourceInUseException - return errors.As(err, &riue) + return nil, errors.As(err, &riue) } output, err := k.client.GetRecordsWithContext(ctx, &kinesis.GetRecordsInput{ @@ -610,35 +663,23 @@ func (k *Kinesumer) consumeOnce(stream string, shard *Shard) bool { var riue *kinesis.ResourceInUseException if errors.As(err, &riue) { - return true + return nil, true } var eie *kinesis.ExpiredIteratorException if errors.As(err, &eie) { k.nextIters[stream].Delete(shard.ID) // Delete expired next iterator cache. 
} - return false + return nil, false } defer k.nextIters[stream].Store(shard.ID, output.NextShardIterator) // Update iter. n := len(output.Records) // We no longer care about shards that have no records left and are in the "CLOSED" state. if n == 0 { - return shard.Closed - } - - var lastSequence string - for i, record := range output.Records { - k.records <- &Record{ - Stream: stream, - Record: record, - } - if i == n-1 { - lastSequence = *record.SequenceNumber - } + return nil, shard.Closed } - k.commitCheckPoint(stream, shard.ID, lastSequence) - return false + return output.Records, false } func (k *Kinesumer) getNextShardIterator( @@ -666,6 +707,92 @@ func (k *Kinesumer) getNextShardIterator( return output.ShardIterator, nil } +func (k *Kinesumer) commitPeriodically() { + var checkPointTicker = time.NewTicker(k.commitInterval) + + for { + select { + case <-k.stop: + return + case <-k.close: + return + case <-checkPointTicker.C: + k.Commit() + } + } +} + +// MarkRecord marks the provided record as consumed. +func (k *Kinesumer) MarkRecord(record *Record) { + if record == nil { + k.sendOrDiscardError(errMarkNilRecord) + return + } + + seqNum := *record.SequenceNumber + if seqNum == "" { + // sequence number can't be empty. + k.sendOrDiscardError(ErrEmptySequenceNumber) + return + } + if _, ok := k.checkPoints[record.Stream]; !ok { + k.sendOrDiscardError(ErrInvalidStream) + return + } + k.offsets[record.Stream].Store(record.ShardID, seqNum) +} + +// Commit updates check point using current checkpoints. +func (k *Kinesumer) Commit() { + var wg sync.WaitGroup + for stream := range k.shards { + wg.Add(1) + + var checkpoints []*ShardCheckPoint + k.offsets[stream].Range(func(shardID, seqNum interface{}) bool { + checkpoints = append(checkpoints, &ShardCheckPoint{ + Stream: stream, + ShardID: shardID.(string), + SequenceNumber: seqNum.(string), + UpdatedAt: time.Now(), + }) + return true + }) + + go func(stream string, checkpoints []*ShardCheckPoint) { + defer wg.Done() + k.commitCheckPointsPerStream(stream, checkpoints) + }(stream, checkpoints) + } + wg.Wait() +} + +// commitCheckPointsPerStream updates checkpoints using sequence number. +func (k *Kinesumer) commitCheckPointsPerStream(stream string, checkpoints []*ShardCheckPoint) { + if len(checkpoints) == 0 { + return + } + + timeoutCtx, cancel := context.WithTimeout(context.Background(), k.commitTimeout) + defer cancel() + + if err := k.stateStore.UpdateCheckPoints(timeoutCtx, checkpoints); err != nil { + k.sendOrDiscardError(errors.Wrapf(err, "failed to commit on stream: %s", stream)) + return + } +} + +// cleanupOffsets remove uninterested stream's shard. +// TODO(proost): how to remove unused stream? +func (k *Kinesumer) cleanupOffsets(stream string, shard *Shard) { + if shard == nil { + return + } + if offsets, ok := k.offsets[stream]; ok { + offsets.Delete(shard.ID) + } +} + // Refresh refreshes the consuming streams. 
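A minimal usage sketch of the manual-commit API introduced above. This is an illustration only, not part of the diff: the app name, table name, region, and stream are placeholders, and it assumes Consume returns the record channel as in the existing kinesumer API. Note that defaults from NewDefaultCommitConfig apply only when Config.Commit is nil, so a custom CommitConfig should set Timeout explicitly.

    package main

    import (
        "log"
        "time"

        "github.com/daangn/kinesumer"
    )

    func main() {
        client, err := kinesumer.NewKinesumer(&kinesumer.Config{
            App:            "myapp",                 // placeholder
            KinesisRegion:  "ap-northeast-2",
            DynamoDBRegion: "ap-northeast-2",
            DynamoDBTable:  "kinesumer-state-store", // placeholder
            Commit: &kinesumer.CommitConfig{
                Auto:    false,           // manual commit: offsets persist only when Commit() is called.
                Timeout: 2 * time.Second, // set explicitly; defaults apply only when Commit is nil.
            },
        })
        if err != nil {
            log.Fatal(err)
        }

        // Drain the error channel so asynchronous errors are not silently discarded.
        go func() {
            for err := range client.Errors() {
                log.Println("kinesumer:", err)
            }
        }()

        records, err := client.Consume([]string{"events"})
        if err != nil {
            log.Fatal(err)
        }
        for record := range records {
            // ... process record.Data ...

            client.MarkRecord(record) // update the in-memory offset for this shard.
            client.Commit()           // flush all marked offsets to the state store.
        }
    }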
func (k *Kinesumer) Refresh(streams []string) { k.mu.Lock() diff --git a/kinesumer_test.go b/kinesumer_test.go index 5a525f5..4c264ad 100644 --- a/kinesumer_test.go +++ b/kinesumer_test.go @@ -1,15 +1,22 @@ package kinesumer import ( + "context" "sort" + "sync" "testing" "time" + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/kinesis" - "github.com/daangn/kinesumer/pkg/collection" "github.com/guregu/dynamo" + "github.com/stretchr/testify/assert" + + "github.com/daangn/kinesumer/pkg/collection" ) type testEnv struct { @@ -308,3 +315,370 @@ func TestShardsRebalancing(t *testing.T) { } } } + +func TestKinesumer_MarkRecordWorksFine(t *testing.T) { + env := newTestEnv(t) + defer env.cleanUp(t) + + streams := []string{"events"} + _, err := env.client1.Consume(streams) + if err != nil { + t.Errorf("expected no errors, got %v", err) + } + + expectedSeqNum := "12345" + shardIDs := env.client1.shards["events"].ids() + for _, shardID := range shardIDs { + env.client1.MarkRecord(&Record{ + Stream: "events", + ShardID: shardID, + Record: &kinesis.Record{ + SequenceNumber: &expectedSeqNum, + }, + }) + } + + for _, shardID := range shardIDs { + resultSeqNum, ok := env.client1.offsets["events"].Load(shardID) + if ok { + assert.EqualValues(t, expectedSeqNum, resultSeqNum, "they should be equal") + } else { + t.Errorf("expected %v, got %v", expectedSeqNum, resultSeqNum) + } + } +} + +func TestKinesumer_MarkRecordFails(t *testing.T) { + testCases := []struct { + name string + kinesumer *Kinesumer + input *Record + wantErr error + }{ + { + name: "when input record is nil", + kinesumer: &Kinesumer{ + errors: make(chan error, 1), + }, + input: nil, + wantErr: errMarkNilRecord, + }, + { + name: "when record sequence number is empty", + kinesumer: &Kinesumer{ + errors: make(chan error, 1), + }, + input: &Record{ + Stream: "foobar", + ShardID: "shardId-000", + Record: &kinesis.Record{ + SequenceNumber: func() *string { + emptyString := "" + return &emptyString + }(), + }, + }, + wantErr: ErrEmptySequenceNumber, + }, + { + name: "when unknown stream is given", + kinesumer: &Kinesumer{ + errors: make(chan error, 1), + checkPoints: map[string]*sync.Map{ + "foobar": {}, + }, + }, + input: &Record{ + Stream: "foo", + ShardID: "shardId-000", + Record: &kinesis.Record{ + SequenceNumber: func() *string { + seq := "0" + return &seq + }(), + }, + }, + wantErr: ErrInvalidStream, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + kinesumer := tc.kinesumer + kinesumer.MarkRecord(tc.input) + + result := <-kinesumer.errors + assert.ErrorIs(t, result, tc.wantErr, "there should be an expected error") + }) + } +} + +func TestKinesumer_Commit(t *testing.T) { + env := newTestEnv(t) + defer env.cleanUp(t) + + streams := []string{"events"} + _, err := env.client1.Consume(streams) + if err != nil { + t.Errorf("expected no errors, got %v", err) + } + _, err = env.client2.Consume(streams) + if err != nil { + t.Errorf("expected no errors, got %v", err) + } + _, err = env.client3.Consume(streams) + if err != nil { + t.Errorf("expected no errors, got %v", err) + } + + clients := map[string]*Kinesumer{ + env.client1.id: env.client1, + env.client2.id: env.client2, + env.client3.id: env.client3, + } + + expectedSeqNum := "12345" + for _, client := range clients { + shardIDs := client.shards["events"].ids() + for _, shardID := range shardIDs { + env.client1.MarkRecord(&Record{ + Stream: "events", 
+ ShardID: shardID, + Record: &kinesis.Record{ + SequenceNumber: &expectedSeqNum, + }, + }) + } + } + + for _, client := range clients { + client.Commit() + } + + for _, client := range clients { + shardIDs := client.shards["events"].ids() + checkpoints, _ := client.stateStore.ListCheckPoints(context.Background(), "events", shardIDs) + for _, checkpoint := range checkpoints { + assert.EqualValues(t, expectedSeqNum, checkpoint, "sequence number should be equal") + } + } +} + +func TestKinesumer_commitCheckPointPerStreamWorksFine(t *testing.T) { + ctrl := gomock.NewController(t) + + input := []*ShardCheckPoint{ + { + Stream: "foobar", + ShardID: "shardId-0", + SequenceNumber: "0", + }, + } + + mockStateStore := NewMockStateStore(ctrl) + mockStateStore.EXPECT(). + UpdateCheckPoints(gomock.Any(), input). + Times(1). + Return(nil) + + offsets := map[string]*sync.Map{} + offsets["foobar"] = &sync.Map{} + offsets["foobar"].Store("shardId-0", "0") + offsets["foobar"].Store("shardId-1", "1") + kinesumer := &Kinesumer{ + offsets: offsets, + commitTimeout: 2 * time.Second, + stateStore: mockStateStore, + } + + kinesumer.commitCheckPointsPerStream("foobar", input) + + select { + case err := <-kinesumer.Errors(): + assert.NoError(t, err, "there should be no error") + default: + } +} + +func TestKinesumer_commitCheckPointPerStreamFails(t *testing.T) { + ctrl := gomock.NewController(t) + + testCases := []struct { + name string + newKinesumer func() *Kinesumer + input struct { + stream string + checkpoints []*ShardCheckPoint + } + wantErrMsg string + }{ + { + name: "when state store fails to update checkpoints", + newKinesumer: func() *Kinesumer { + mockStateStore := NewMockStateStore(ctrl) + mockStateStore.EXPECT(). + UpdateCheckPoints(gomock.Any(), gomock.Any()). + Times(1). 
+ Return(errors.New("mock error")) + return &Kinesumer{ + errors: make(chan error, 1), + stateStore: mockStateStore, + } + }, + input: struct { + stream string + checkpoints []*ShardCheckPoint + }{ + stream: "foobar", + checkpoints: []*ShardCheckPoint{ + { + Stream: "foobar", + ShardID: "shardId-000", + SequenceNumber: "0", + }, + }, + }, + wantErrMsg: "failed to commit on stream: foobar: mock error", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + kinesumer := tc.newKinesumer() + kinesumer.commitCheckPointsPerStream(tc.input.stream, tc.input.checkpoints) + result := <-kinesumer.errors + assert.EqualError(t, result, tc.wantErrMsg, "there should be an expected error") + }) + } +} + +func TestKinesumer_cleanupOffsetsWorksFine(t *testing.T) { + testCases := []struct { + name string + newKinesumer func() *Kinesumer + input struct { + stream string + shard *Shard + } + want map[string]map[string]string // stream - shard - sequence number + }{ + { + name: "when clean up existing offset", + newKinesumer: func() *Kinesumer { + offsets := map[string]*sync.Map{} + + offsets["foobar"] = &sync.Map{} + offsets["foobar"].Store("shardId-0", "0") + + offsets["foo"] = &sync.Map{} + offsets["foo"].Store("shardId-1", "1") + return &Kinesumer{ + offsets: offsets, + } + }, + input: struct { + stream string + shard *Shard + }{ + stream: "foobar", + shard: &Shard{ + ID: "shardId-0", + }, + }, + want: map[string]map[string]string{ + "foo": { + "shardId-1": "1", + }, + }, + }, + { + name: "when clean up non-existent stream", + newKinesumer: func() *Kinesumer { + offsets := map[string]*sync.Map{} + + offsets["foobar"] = &sync.Map{} + offsets["foobar"].Store("shardId-0", "10") + + offsets["foo"] = &sync.Map{} + offsets["foo"].Store("shardId-1", "20") + return &Kinesumer{ + offsets: offsets, + } + }, + input: struct { + stream string + shard *Shard + }{ + stream: "bar", + shard: &Shard{ + ID: "shardId-2", + }, + }, + want: map[string]map[string]string{ + "foobar": { + "shardId-0": "10", + }, + "foo": { + "shardId-1": "20", + }, + }, + }, + { + name: "when clean up non-existent shard", + newKinesumer: func() *Kinesumer { + offsets := map[string]*sync.Map{} + + offsets["foobar"] = &sync.Map{} + offsets["foobar"].Store("shardId-0", "10") + + offsets["foo"] = &sync.Map{} + offsets["foo"].Store("shardId-1", "20") + return &Kinesumer{ + offsets: offsets, + } + }, + input: struct { + stream string + shard *Shard + }{ + stream: "foo", + shard: &Shard{ + ID: "shardId-2", + }, + }, + want: map[string]map[string]string{ + "foobar": { + "shardId-0": "10", + }, + "foo": { + "shardId-1": "20", + }, + }, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + kinesumer := tc.newKinesumer() + kinesumer.cleanupOffsets(tc.input.stream, tc.input.shard) + + result := make(map[string]map[string]string) + for stream, offsets := range kinesumer.offsets { + streamResult := make(map[string]string) + offsets.Range(func(shardID, sequence interface{}) bool { + streamResult[shardID.(string)] = sequence.(string) + return true + }) + result[stream] = streamResult + } + + for stream, expectedInStream := range tc.want { + if assert.NotEmpty(t, result[stream]) { + streamResult := result[stream] + for expectedShardID, expectedSeqNum := range expectedInStream { + if assert.NotEmpty(t, streamResult[expectedShardID]) { + assert.EqualValues(t, streamResult[expectedShardID], expectedSeqNum, "they should be equal") + } + } + } + } + }) + } +} diff --git a/statemodel.go b/statemodel.go index 
cfa9612..fee2de7 100644 --- a/statemodel.go +++ b/statemodel.go @@ -36,6 +36,14 @@ func buildClientKey(app string) string { return buildKeyFn(clientKeyFmt, app) } +// ShardCheckPoint manages a shard check point. +type ShardCheckPoint struct { + Stream string + ShardID string + SequenceNumber string + UpdatedAt time.Time +} + // stateCheckPoint manages record check points. type stateCheckPoint struct { StreamKey string `dynamo:"pk,pk"` diff --git a/statestore.go b/statestore.go index 07d76e5..38225d7 100644 --- a/statestore.go +++ b/statestore.go @@ -10,6 +10,8 @@ import ( "github.com/pkg/errors" ) +//go:generate mockgen -source=statestore.go -destination=statestore_mock.go -package=kinesumer StateStore + // Error codes. var ( ErrNoShardCache = errors.New("kinesumer: shard cache not found") @@ -27,7 +29,7 @@ type ( PingClientAliveness(ctx context.Context, clientID string) error PruneClients(ctx context.Context) error ListCheckPoints(ctx context.Context, stream string, shardIDs []string) (map[string]string, error) - UpdateCheckPoint(ctx context.Context, stream, shardID, seq string) error + UpdateCheckPoints(ctx context.Context, checkpoints []*ShardCheckPoint) error } db struct { @@ -233,7 +235,7 @@ func (s *stateStore) ListCheckPoints( for _, id := range shardIDs { keys = append( keys, - dynamo.Keys{buildCheckPointKey(s.app, id), stream}, + dynamo.Keys{buildCheckPointKey(s.app, stream), id}, ) } @@ -254,21 +256,25 @@ func (s *stateStore) ListCheckPoints( return seqMap, nil } -// UpdateCheckPoint updates the check point sequence number for a shard. -func (s *stateStore) UpdateCheckPoint( - ctx context.Context, stream, shardID, seq string, -) error { - var ( - key = buildCheckPointKey(s.app, stream) - now = time.Now() - ) - checkPoint := stateCheckPoint{ - StreamKey: key, - ShardID: shardID, - SequenceNumber: seq, - LastUpdate: now, +// UpdateCheckPoints updates the check point sequence numbers for multiple shards. +func (s *stateStore) UpdateCheckPoints(ctx context.Context, checkpoints []*ShardCheckPoint) error { + stateCheckPoints := make([]interface{}, len(checkpoints)) + for i, checkpoint := range checkpoints { + stateCheckPoints[i] = stateCheckPoint{ + StreamKey: buildCheckPointKey(s.app, checkpoint.Stream), + ShardID: checkpoint.ShardID, + SequenceNumber: checkpoint.SequenceNumber, + LastUpdate: checkpoint.UpdatedAt, + } } - if err := s.db.table.Put(checkPoint).RunWithContext(ctx); err != nil { + + // TODO(proost): check written bytes + _, err := s.db.table. + Batch("pk", "sk"). + Write(). + Put(stateCheckPoints...). + RunWithContext(ctx) + if err != nil { return errors.WithStack(err) } return nil diff --git a/statestore_mock.go b/statestore_mock.go new file mode 100644 index 0000000..9fa9ada --- /dev/null +++ b/statestore_mock.go @@ -0,0 +1,164 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: statestore.go + +// Package kinesumer is a generated GoMock package. +package kinesumer + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockStateStore is a mock of StateStore interface. +type MockStateStore struct { + ctrl *gomock.Controller + recorder *MockStateStoreMockRecorder +} + +// MockStateStoreMockRecorder is the mock recorder for MockStateStore. +type MockStateStoreMockRecorder struct { + mock *MockStateStore +} + +// NewMockStateStore creates a new mock instance. 
+func NewMockStateStore(ctrl *gomock.Controller) *MockStateStore { + mock := &MockStateStore{ctrl: ctrl} + mock.recorder = &MockStateStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStateStore) EXPECT() *MockStateStoreMockRecorder { + return m.recorder +} + +// DeregisterClient mocks base method. +func (m *MockStateStore) DeregisterClient(ctx context.Context, clientID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeregisterClient", ctx, clientID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeregisterClient indicates an expected call of DeregisterClient. +func (mr *MockStateStoreMockRecorder) DeregisterClient(ctx, clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeregisterClient", reflect.TypeOf((*MockStateStore)(nil).DeregisterClient), ctx, clientID) +} + +// GetShards mocks base method. +func (m *MockStateStore) GetShards(ctx context.Context, stream string) (Shards, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetShards", ctx, stream) + ret0, _ := ret[0].(Shards) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetShards indicates an expected call of GetShards. +func (mr *MockStateStoreMockRecorder) GetShards(ctx, stream interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetShards", reflect.TypeOf((*MockStateStore)(nil).GetShards), ctx, stream) +} + +// ListAllAliveClientIDs mocks base method. +func (m *MockStateStore) ListAllAliveClientIDs(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAllAliveClientIDs", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAllAliveClientIDs indicates an expected call of ListAllAliveClientIDs. +func (mr *MockStateStoreMockRecorder) ListAllAliveClientIDs(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllAliveClientIDs", reflect.TypeOf((*MockStateStore)(nil).ListAllAliveClientIDs), ctx) +} + +// ListCheckPoints mocks base method. +func (m *MockStateStore) ListCheckPoints(ctx context.Context, stream string, shardIDs []string) (map[string]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListCheckPoints", ctx, stream, shardIDs) + ret0, _ := ret[0].(map[string]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListCheckPoints indicates an expected call of ListCheckPoints. +func (mr *MockStateStoreMockRecorder) ListCheckPoints(ctx, stream, shardIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListCheckPoints", reflect.TypeOf((*MockStateStore)(nil).ListCheckPoints), ctx, stream, shardIDs) +} + +// PingClientAliveness mocks base method. +func (m *MockStateStore) PingClientAliveness(ctx context.Context, clientID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PingClientAliveness", ctx, clientID) + ret0, _ := ret[0].(error) + return ret0 +} + +// PingClientAliveness indicates an expected call of PingClientAliveness. +func (mr *MockStateStoreMockRecorder) PingClientAliveness(ctx, clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PingClientAliveness", reflect.TypeOf((*MockStateStore)(nil).PingClientAliveness), ctx, clientID) +} + +// PruneClients mocks base method. 
+func (m *MockStateStore) PruneClients(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PruneClients", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// PruneClients indicates an expected call of PruneClients. +func (mr *MockStateStoreMockRecorder) PruneClients(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PruneClients", reflect.TypeOf((*MockStateStore)(nil).PruneClients), ctx) +} + +// RegisterClient mocks base method. +func (m *MockStateStore) RegisterClient(ctx context.Context, clientID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterClient", ctx, clientID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterClient indicates an expected call of RegisterClient. +func (mr *MockStateStoreMockRecorder) RegisterClient(ctx, clientID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterClient", reflect.TypeOf((*MockStateStore)(nil).RegisterClient), ctx, clientID) +} + +// UpdateCheckPoints mocks base method. +func (m *MockStateStore) UpdateCheckPoints(ctx context.Context, checkpoints []*ShardCheckPoint) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateCheckPoints", ctx, checkpoints) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateCheckPoints indicates an expected call of UpdateCheckPoints. +func (mr *MockStateStoreMockRecorder) UpdateCheckPoints(ctx, checkpoints interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCheckPoints", reflect.TypeOf((*MockStateStore)(nil).UpdateCheckPoints), ctx, checkpoints) +} + +// UpdateShards mocks base method. +func (m *MockStateStore) UpdateShards(ctx context.Context, stream string, shards Shards) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateShards", ctx, stream, shards) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateShards indicates an expected call of UpdateShards. +func (mr *MockStateStoreMockRecorder) UpdateShards(ctx, stream, shards interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateShards", reflect.TypeOf((*MockStateStore)(nil).UpdateShards), ctx, stream, shards) +} diff --git a/statestore_test.go b/statestore_test.go new file mode 100644 index 0000000..80f93c1 --- /dev/null +++ b/statestore_test.go @@ -0,0 +1,108 @@ +package kinesumer + +import ( + "context" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/guregu/dynamo" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func newTestDynamoDB(t *testing.T) *dynamo.DB { + awsCfg := aws.NewConfig() + awsCfg.WithRegion("ap-northeast-2") + awsCfg.WithEndpoint("http://localhost:14566") + sess, err := session.NewSession(awsCfg) + if err != nil { + t.Fatal("failed to init test env:", err.Error()) + } + return dynamo.New(sess) +} + +func cleanUpStateStore(t *testing.T, store *stateStore) { + type PkSk struct { + PK string `dynamo:"pk"` + SK string `dynamo:"sk"` + } + + var ( + pksks []*PkSk + keys []dynamo.Keyed + ) + table := store.db.table + if err := table.Scan().All(&pksks); err != nil { + t.Fatal("failed to scan the state table:", err.Error()) + } + for _, pksk := range pksks { + keys = append(keys, &dynamo.Keys{pksk.PK, pksk.SK}) + } + if _, err := table. + Batch("pk", "sk"). + Write(). + Delete(keys...). 
+ Run(); err != nil { + t.Fatal("failed to delete all test data:", err.Error()) + } +} + +func TestStateStore_UpdateCheckPointsWorksFine(t *testing.T) { + cfg := &Config{ + App: "test", + DynamoDBRegion: "ap-northeast-2", + DynamoDBTable: "kinesumer-state-store", + DynamoDBEndpoint: "http://localhost:14566", + } + store, err := newStateStore(cfg) + assert.NoError(t, err, "there should be no error") + + s, _ := store.(*stateStore) + defer cleanUpStateStore(t, s) + + expectedUpdatedAt := time.Date(2022, 7, 12, 12, 35, 0, 0, time.UTC) + + expected := []*stateCheckPoint{ + { + StreamKey: buildCheckPointKey("test", "foobar"), + ShardID: "shardId-000", + SequenceNumber: "0", + LastUpdate: expectedUpdatedAt, + }, + } + + err = s.UpdateCheckPoints( + context.Background(), + []*ShardCheckPoint{ + { + Stream: "foobar", + ShardID: "shardId-000", + SequenceNumber: "0", + UpdatedAt: expectedUpdatedAt, + }, + }, + ) + if assert.NoError(t, err, "there should be no error") { + assert.Eventually( + t, + func() bool { + var result []*stateCheckPoint + err := s.db.table. + Batch("pk", "sk"). + Get( + []dynamo.Keyed{ + dynamo.Keys{buildCheckPointKey("test", "foobar"), "shardId-000"}, + dynamo.Keys{buildCheckPointKey("test", "foo"), "shardId-001"}, + }..., + ).All(&result) + if assert.NoError(t, err) { + return assert.EqualValues(t, expected, result) + } + return false + }, + 600*time.Millisecond, + 100*time.Millisecond, + "they should be equal", + ) + } +} diff --git a/syncclient.go b/syncclient.go index 311b2b7..5e5457f 100644 --- a/syncclient.go +++ b/syncclient.go @@ -151,6 +151,9 @@ func (k *Kinesumer) syncShardInfoForStream( if _, ok := k.checkPoints[stream]; !ok { k.checkPoints[stream] = &sync.Map{} } + if _, ok := k.offsets[stream]; !ok { + k.offsets[stream] = &sync.Map{} + } // Delete uninterested shard ids. k.checkPoints[stream].Range(func(key, _ interface{}) bool {