From 41e78784b0d9726c104d960a8f8e55688df9da26 Mon Sep 17 00:00:00 2001 From: samhith-kakarla Date: Fri, 12 Jul 2024 12:47:53 -0700 Subject: [PATCH] test: Unit tests code coverage improvement (#1781) Signed-off-by: Samhith Kakarla --- go.mod | 1 + go.sum | 1 + pkg/isb/stores/jetstream/reader_test.go | 7 +- pkg/isb/stores/jetstream/writer_test.go | 12 +- pkg/shared/clients/nats/client_pool_test.go | 62 ++++ pkg/shared/clients/nats/nats_client.go | 6 + pkg/shared/clients/nats/nats_client_test.go | 127 +++++++ pkg/shared/clients/nats/options_test.go | 33 ++ pkg/shared/clients/nats/test/client.go | 30 -- pkg/shared/clients/nats/test/server.go | 4 +- pkg/shared/clients/redis/options_test.go | 93 +++++ pkg/shared/clients/redis/redis_client_test.go | 310 ++++++++++++++++- pkg/shared/clients/redis/redis_reader_test.go | 324 ++++++++++++++++++ pkg/shared/idlehandler/idlehandler_test.go | 120 +++++++ .../idlehandler/source_idlehandler_test.go | 45 +++ pkg/shared/kvs/jetstream/kv_store_test.go | 183 +++++++++- pkg/shared/util/json_test.go | 29 ++ pkg/shared/util/kubeconfig_test.go | 38 ++ pkg/shared/util/sasl_config_test.go | 181 ++++++++++ pkg/shared/util/tls_config_test.go | 74 ++++ pkg/shared/util/volume_test.go | 39 +++ .../initializer/initializer_test.go | 5 +- .../synchronizer/synchronizer_test.go | 3 +- pkg/watermark/fetch/edge_fetcher_test.go | 9 +- pkg/watermark/publish/publisher_test.go | 3 +- 25 files changed, 1684 insertions(+), 55 deletions(-) create mode 100644 pkg/shared/clients/nats/client_pool_test.go create mode 100644 pkg/shared/clients/nats/nats_client_test.go create mode 100644 pkg/shared/clients/nats/options_test.go delete mode 100644 pkg/shared/clients/nats/test/client.go create mode 100644 pkg/shared/clients/redis/options_test.go create mode 100644 pkg/shared/clients/redis/redis_reader_test.go create mode 100644 pkg/shared/idlehandler/idlehandler_test.go create mode 100644 pkg/shared/util/kubeconfig_test.go create mode 100644 pkg/shared/util/tls_config_test.go diff --git a/go.mod b/go.mod index dc6be6aeab..2560e8dc81 100644 --- a/go.mod +++ b/go.mod @@ -179,6 +179,7 @@ require ( github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tidwall/gjson v1.14.4 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index 4751ad152b..eb4d23b5cf 100644 --- a/go.sum +++ b/go.sum @@ -585,6 +585,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v0.0.0-20161117074351-18a02ba4a312/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= diff --git a/pkg/isb/stores/jetstream/reader_test.go b/pkg/isb/stores/jetstream/reader_test.go index 5227760a36..cf35441b26 100644 --- a/pkg/isb/stores/jetstream/reader_test.go +++ b/pkg/isb/stores/jetstream/reader_test.go @@ -28,6 +28,7 @@ import ( "github.com/numaproj/numaflow/pkg/isb" "github.com/numaproj/numaflow/pkg/isb/testutils" + natsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test" ) @@ -44,7 +45,7 @@ func TestJetStreamBufferRead(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err) @@ -137,7 +138,7 @@ func TestGetName(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err) defer defaultJetStreamClient.Close() @@ -162,7 +163,7 @@ func TestClose(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err) diff --git a/pkg/isb/stores/jetstream/writer_test.go b/pkg/isb/stores/jetstream/writer_test.go index 665d641f49..8408f5b244 100644 --- a/pkg/isb/stores/jetstream/writer_test.go +++ b/pkg/isb/stores/jetstream/writer_test.go @@ -27,7 +27,9 @@ import ( "github.com/numaproj/numaflow/pkg/forwarder" "github.com/numaproj/numaflow/pkg/isb" "github.com/numaproj/numaflow/pkg/isb/testutils" + natsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test" + "github.com/numaproj/numaflow/pkg/udf/forward" "github.com/numaproj/numaflow/pkg/watermark/generic" "github.com/numaproj/numaflow/pkg/watermark/wmb" @@ -80,7 +82,7 @@ func TestForwarderJetStreamBuffer(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err) @@ -208,7 +210,7 @@ func TestJetStreamBufferWriterBufferFull(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err) @@ -265,7 +267,7 @@ func TestJetStreamBufferWriterBufferFull_DiscardLatest(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err) @@ -321,7 +323,7 @@ func TestWriteGetName(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err) @@ -347,7 +349,7 @@ func TestWriteClose(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err) diff --git a/pkg/shared/clients/nats/client_pool_test.go b/pkg/shared/clients/nats/client_pool_test.go new file mode 100644 index 0000000000..8e8be47247 --- /dev/null +++ b/pkg/shared/clients/nats/client_pool_test.go @@ -0,0 +1,62 @@ +package nats + +import ( + "context" + "os" + "testing" + + dfv1 "github.com/numaproj/numaflow/pkg/apis/numaflow/v1alpha1" + "github.com/stretchr/testify/assert" +) + +func TestNewClientPool_Success(t *testing.T) { + os.Setenv(dfv1.EnvISBSvcJetStreamURL, "nats://localhost:4222") + os.Setenv(dfv1.EnvISBSvcJetStreamUser, "user") + os.Setenv(dfv1.EnvISBSvcJetStreamPassword, "password") + ctx := context.Background() + pool, err := NewClientPool(ctx) + + assert.NoError(t, err) + assert.NotNil(t, pool) + assert.Equal(t, 3, pool.clients.Len()) // Check if the pool size matches the default clientPoolSize +} + +func TestClientPool_NextAvailableClient(t *testing.T) { + os.Setenv(dfv1.EnvISBSvcJetStreamURL, "nats://localhost:4222") + os.Setenv(dfv1.EnvISBSvcJetStreamUser, "user") + os.Setenv(dfv1.EnvISBSvcJetStreamPassword, "password") + ctx := context.Background() + pool, err := NewClientPool(ctx) + assert.NoError(t, err) + assert.NotNil(t, pool) + + client1 := pool.NextAvailableClient() + assert.NotNil(t, client1) + + client2 := pool.NextAvailableClient() + assert.NotNil(t, client2) + + client3 := pool.NextAvailableClient() + assert.NotNil(t, client3) +} + +func TestClientPool_CloseAll(t *testing.T) { + os.Setenv(dfv1.EnvISBSvcJetStreamURL, "nats://localhost:4222") + os.Setenv(dfv1.EnvISBSvcJetStreamUser, "user") + os.Setenv(dfv1.EnvISBSvcJetStreamPassword, "password") + ctx := context.Background() + pool, err := NewClientPool(ctx) + assert.NoError(t, err) + assert.NotNil(t, pool) + + for e := pool.clients.Front(); e != nil; e = e.Next() { + client := e.Value.(*Client) + assert.False(t, client.nc.IsClosed()) + } + + pool.CloseAll() + for e := pool.clients.Front(); e != nil; e = e.Next() { + client := e.Value.(*Client) + assert.True(t, client.nc.IsClosed()) + } +} diff --git a/pkg/shared/clients/nats/nats_client.go b/pkg/shared/clients/nats/nats_client.go index 9f6ccff8b8..67bd2029d5 100644 --- a/pkg/shared/clients/nats/nats_client.go +++ b/pkg/shared/clients/nats/nats_client.go @@ -25,6 +25,7 @@ import ( "testing" "time" + "github.com/nats-io/nats-server/v2/server" "github.com/nats-io/nats.go" "go.uber.org/zap" @@ -181,3 +182,8 @@ func NewTestClient(t *testing.T, url string) *Client { } return &Client{nc: nc} } + +// JetStreamClient is used to get a testing JetStream client instance +func NewTestClientWithServer(t *testing.T, s *server.Server) *Client { + return NewTestClient(t, s.ClientURL()) +} diff --git a/pkg/shared/clients/nats/nats_client_test.go b/pkg/shared/clients/nats/nats_client_test.go new file mode 100644 index 0000000000..7e81277f0e --- /dev/null +++ b/pkg/shared/clients/nats/nats_client_test.go @@ -0,0 +1,127 @@ +package nats + +import ( + "context" + "os" + "testing" + + "github.com/nats-io/nats.go" + dfv1 "github.com/numaproj/numaflow/pkg/apis/numaflow/v1alpha1" + "github.com/numaproj/numaflow/pkg/shared/logging" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + + natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test" +) + +func TestNewNATSClient(t *testing.T) { + // Setting up environment variables for the test + os.Setenv(dfv1.EnvISBSvcJetStreamURL, "nats://localhost:4222") + os.Setenv(dfv1.EnvISBSvcJetStreamUser, "user") + os.Setenv(dfv1.EnvISBSvcJetStreamPassword, "password") + defer os.Clearenv() + + log := zap.NewNop().Sugar() + + ctx := logging.WithLogger(context.Background(), log) + + client, err := NewNATSClient(ctx) + assert.NoError(t, err) + assert.NotNil(t, client) + + // Cleanup + client.Close() +} + +func TestNewNATSClient_Failure(t *testing.T) { + // Simulating environment variable absence + os.Clearenv() + + log := zap.NewNop().Sugar() + ctx := logging.WithLogger(context.Background(), log) + + client, err := NewNATSClient(ctx) + assert.Error(t, err) + assert.Nil(t, client) +} + +func TestSubscribe(t *testing.T) { + s := natstest.RunJetStreamServer(t) + defer s.Shutdown() + + client := NewTestClient(t, s.ClientURL()) + defer client.Close() + + // Create a stream + js, err := client.nc.JetStream() + assert.NoError(t, err) + _, err = js.AddStream(&nats.StreamConfig{ + Name: "TEST_STREAM", + Subjects: []string{"test.subject"}, + }) + assert.NoError(t, err) + + // Subscribe to a subject + sub, err := client.Subscribe("test.subject", "TEST_STREAM") + assert.NoError(t, err) + assert.NotNil(t, sub) + + // Test failure case: Invalid stream + _, err = client.Subscribe("balh", "INVALID_STREAM") + assert.Error(t, err) +} + +func TestBindKVStore(t *testing.T) { + s := natstest.RunJetStreamServer(t) + defer s.Shutdown() + + client := NewTestClient(t, s.ClientURL()) + defer client.Close() + + // Create a KeyValue store + js, err := client.nc.JetStream() + assert.NoError(t, err) + _, err = js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: "KV_TEST", + }) + assert.NoError(t, err) + + // Bind to the KeyValue store + kvStore, err := client.BindKVStore("KV_TEST") + assert.NoError(t, err) + assert.NotNil(t, kvStore) + + // Test failure case: Invalid KeyValue store + _, err = client.BindKVStore("INVALID_KV") + assert.Error(t, err) +} + +func TestJetStreamContext(t *testing.T) { + s := natstest.RunJetStreamServer(t) + defer s.Shutdown() + + client := NewTestClient(t, s.ClientURL()) + defer client.Close() + + jsCtx, err := client.JetStreamContext() + assert.NoError(t, err) + assert.NotNil(t, jsCtx) +} + +func TestNewTestClient(t *testing.T) { + s := natstest.RunJetStreamServer(t) + defer s.Shutdown() + + client := NewTestClient(t, s.ClientURL()) + assert.NotNil(t, client) + defer client.Close() +} + +func TestClose(t *testing.T) { + s := natstest.RunJetStreamServer(t) + defer s.Shutdown() + + client := NewTestClient(t, s.ClientURL()) + assert.NotNil(t, client) + client.Close() +} diff --git a/pkg/shared/clients/nats/options_test.go b/pkg/shared/clients/nats/options_test.go new file mode 100644 index 0000000000..87aa3f1858 --- /dev/null +++ b/pkg/shared/clients/nats/options_test.go @@ -0,0 +1,33 @@ +package nats + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultOptions(t *testing.T) { + opts := defaultOptions() + assert.NotNil(t, opts) + assert.Equal(t, 3, opts.clientPoolSize, "default client pool size should be 3") +} + +func TestWithClientPoolSize(t *testing.T) { + opts := defaultOptions() + assert.Equal(t, 3, opts.clientPoolSize, "default client pool size should be 3") + + option := WithClientPoolSize(10) + option(opts) + + assert.Equal(t, 10, opts.clientPoolSize, "client pool size should be set to 10") +} + +func TestCombinedOptions(t *testing.T) { + opts := defaultOptions() + assert.Equal(t, 3, opts.clientPoolSize, "default client pool size should be 3") + + option1 := WithClientPoolSize(5) + option1(opts) + + assert.Equal(t, 5, opts.clientPoolSize, "client pool size should be set to 5") +} diff --git a/pkg/shared/clients/nats/test/client.go b/pkg/shared/clients/nats/test/client.go deleted file mode 100644 index 01bfef5aca..0000000000 --- a/pkg/shared/clients/nats/test/client.go +++ /dev/null @@ -1,30 +0,0 @@ -/* -Copyright 2022 The Numaproj Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package test - -import ( - "testing" - - "github.com/nats-io/nats-server/v2/server" - - "github.com/numaproj/numaflow/pkg/shared/clients/nats" -) - -// JetStreamClient is used to get a testing JetStream client instance -func JetStreamClient(t *testing.T, s *server.Server) *nats.Client { - return nats.NewTestClient(t, s.ClientURL()) -} diff --git a/pkg/shared/clients/nats/test/server.go b/pkg/shared/clients/nats/test/server.go index 1f2a4d5f47..abe5722374 100644 --- a/pkg/shared/clients/nats/test/server.go +++ b/pkg/shared/clients/nats/test/server.go @@ -1,6 +1,5 @@ /* Copyright 2022 The Numaproj Authors. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -14,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ -package test +package nats import ( "os" "testing" "github.com/nats-io/nats-server/v2/server" + natstestserver "github.com/nats-io/nats-server/v2/test" ) diff --git a/pkg/shared/clients/redis/options_test.go b/pkg/shared/clients/redis/options_test.go new file mode 100644 index 0000000000..2e2d05e610 --- /dev/null +++ b/pkg/shared/clients/redis/options_test.go @@ -0,0 +1,93 @@ +package redis + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWithoutPipelining(t *testing.T) { + opts := &Options{} + option := WithoutPipelining() + option.Apply(opts) + assert.False(t, opts.Pipelining) +} + +func TestWithInfoRefreshInterval(t *testing.T) { + opts := &Options{} + interval := 10 * time.Second + option := WithInfoRefreshInterval(interval) + option.Apply(opts) + assert.Equal(t, interval, opts.InfoRefreshInterval) +} + +func TestWithLagDuration(t *testing.T) { + opts := &Options{} + lag := 5 * time.Minute + option := WithLagDuration(lag) + option.Apply(opts) + assert.Equal(t, lag, opts.LagDuration) +} + +func TestWithReadTimeOut(t *testing.T) { + opts := &Options{} + timeout := 2 * time.Second + option := WithReadTimeOut(timeout) + option.Apply(opts) + assert.Equal(t, timeout, opts.ReadTimeOut) +} + +func TestWithCheckBacklog(t *testing.T) { + opts := &Options{} + option := WithCheckBacklog(true) + option.Apply(opts) + assert.True(t, opts.CheckBackLog) +} + +func TestWithMaxLength(t *testing.T) { + opts := &Options{} + maxLength := int64(100) + option := WithMaxLength(maxLength) + option.Apply(opts) + assert.Equal(t, maxLength, opts.MaxLength) +} + +func TestWithBufferUsageLimit(t *testing.T) { + opts := &Options{} + limit := float64(0.75) + option := WithBufferUsageLimit(limit) + option.Apply(opts) + assert.Equal(t, limit, opts.BufferUsageLimit) +} + +func TestWithRefreshBufferWriteInfo(t *testing.T) { + opts := &Options{} + option := WithRefreshBufferWriteInfo(true) + option.Apply(opts) + assert.True(t, opts.RefreshBufferWriteInfo) +} + +func TestOptionInterfaceImplementation(t *testing.T) { + var _ Option = pipelining(true) + var _ Option = infoRefreshInterval(1 * time.Second) + var _ Option = lagDuration(1 * time.Second) + var _ Option = readTimeOut(1 * time.Second) + var _ Option = checkBackLog(true) + var _ Option = maxLength(100) + var _ Option = bufferUsageLimit(0.5) + var _ Option = refreshBufferWriteInfo(true) +} + +// Test default values +func TestOptionsDefaultValues(t *testing.T) { + opts := Options{} + assert.False(t, opts.Pipelining) + assert.Equal(t, time.Duration(0), opts.InfoRefreshInterval) + assert.Equal(t, time.Duration(0), opts.LagDuration) + assert.Equal(t, time.Duration(0), opts.ReadTimeOut) + assert.False(t, opts.CheckBackLog) + assert.Equal(t, int64(0), opts.MaxLength) + assert.Equal(t, float64(0), opts.BufferUsageLimit) + assert.False(t, opts.RefreshBufferWriteInfo) +} diff --git a/pkg/shared/clients/redis/redis_client_test.go b/pkg/shared/clients/redis/redis_client_test.go index 79e5ccdfa7..09aa765985 100644 --- a/pkg/shared/clients/redis/redis_client_test.go +++ b/pkg/shared/clients/redis/redis_client_test.go @@ -14,22 +14,39 @@ See the License for the specific language governing permissions and limitations under the License. */ +// package redis + +// import ( +// "context" +// "testing" + +// "github.com/redis/go-redis/v9" +// "github.com/stretchr/testify/assert" +// ) + package redis import ( "context" + "errors" + "fmt" + "os" "testing" + "github.com/numaproj/numaflow/pkg/apis/numaflow/v1alpha1" "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" ) func TestNewRedisClient(t *testing.T) { t.SkipNow() ctx := context.TODO() - client := NewRedisClient(&redis.UniversalOptions{ + redisOptions := &redis.UniversalOptions{ Addrs: []string{":6379"}, - }) + } + client := NewRedisClient(redisOptions) var stream = "foo" var streamGroup = "foo-group" err := client.CreateStreamGroup(ctx, stream, streamGroup, ReadFromEarliest) @@ -42,3 +59,292 @@ func TestNewRedisClient(t *testing.T) { err = client.CreateStreamGroup(ctx, stream, streamGroup, ReadFromEarliest) assert.Error(t, err) } + +// Mock environment variables +func setEnv(key, value string) { + os.Setenv(key, value) +} + +func unsetEnv(key string) { + os.Unsetenv(key) +} + +func TestNewInClusterRedisClient(t *testing.T) { + // Set environment variables + setEnv(v1alpha1.EnvISBSvcRedisUser, "user") + setEnv(v1alpha1.EnvISBSvcRedisPassword, "password") + setEnv(v1alpha1.EnvISBSvcRedisURL, ":6379") + setEnv(v1alpha1.EnvISBSvcRedisClusterMaxRedirects, "5") + + client := NewInClusterRedisClient() + assert.NotNil(t, client) + + // Cleanup environment variables + unsetEnv(v1alpha1.EnvISBSvcRedisUser) + unsetEnv(v1alpha1.EnvISBSvcRedisPassword) + unsetEnv(v1alpha1.EnvISBSvcRedisURL) + unsetEnv(v1alpha1.EnvISBSvcRedisClusterMaxRedirects) +} + +func TestErrorHelpers(t *testing.T) { + errExist := fmt.Errorf("BUSYGROUP Consumer Group name already exists") + assert.True(t, IsAlreadyExistError(errExist)) + + errNotFound := fmt.Errorf("requires the key to exist") + assert.True(t, NotFoundError(errNotFound)) +} + +func TestGetRedisStreamName(t *testing.T) { + streamName := "test-stream" + expected := "{test-stream}" + got := GetRedisStreamName(streamName) + assert.Equal(t, expected, got) +} + +// Define a mock Redis client +type MockRedisClient struct { + mock.Mock + redis.UniversalClient +} + +func (m *MockRedisClient) XGroupCreateMkStream(ctx context.Context, stream string, group string, start string) *redis.StatusCmd { + args := m.Called(ctx, stream, group, start) + cmd := redis.NewStatusCmd(ctx) + cmd.SetErr(args.Error(0)) + return cmd +} + +func (m *MockRedisClient) XGroupDestroy(ctx context.Context, stream string, group string) *redis.IntCmd { + args := m.Called(ctx, stream, group) + cmd := redis.NewIntCmd(ctx) + cmd.SetErr(args.Error(0)) + return cmd +} + +func (m *MockRedisClient) Del(ctx context.Context, keys ...string) *redis.IntCmd { + args := m.Called(ctx, keys) + cmd := redis.NewIntCmd(ctx) + cmd.SetErr(args.Error(0)) + return cmd +} + +func (m *MockRedisClient) XInfoStream(ctx context.Context, stream string) *redis.XInfoStreamCmd { + args := m.Called(ctx, stream) + cmd := redis.NewXInfoStreamCmd(ctx, stream) + if streamInfo, ok := args.Get(1).(*redis.XInfoStream); ok { + cmd.SetVal(streamInfo) + } + cmd.SetErr(args.Error(0)) + return cmd +} + +func (m *MockRedisClient) XInfoGroups(ctx context.Context, stream string) *redis.XInfoGroupsCmd { + args := m.Called(ctx, stream) + cmd := redis.NewXInfoGroupsCmd(ctx, stream) + if groups, ok := args.Get(1).([]redis.XInfoGroup); ok { + cmd.SetVal(groups) + } + cmd.SetErr(args.Error(0)) + return cmd +} + +func (m *MockRedisClient) XPending(ctx context.Context, stream, group string) *redis.XPendingCmd { + args := m.Called(ctx, stream, group) + cmd := redis.NewXPendingCmd(ctx, stream, group) + if pendingInfo, ok := args.Get(1).(*redis.XPending); ok { + cmd.SetVal(pendingInfo) + } + cmd.SetErr(args.Error(0)) + return cmd +} + +// Test suite for RedisClient +type RedisClientTestSuite struct { + suite.Suite + client *RedisClient + mock *MockRedisClient +} + +func (suite *RedisClientTestSuite) SetupTest() { + suite.mock = new(MockRedisClient) + suite.client = &RedisClient{ + Client: suite.mock, + } +} + +func (suite *RedisClientTestSuite) TestCreateStreamGroup_Success() { + suite.mock.On("XGroupCreateMkStream", mock.Anything, "mystream", "mygroup", "0").Return(nil) + + err := suite.client.CreateStreamGroup(context.Background(), "mystream", "mygroup", "0") + + suite.NoError(err) + suite.mock.AssertCalled(suite.T(), "XGroupCreateMkStream", mock.Anything, "mystream", "mygroup", "0") +} + +func (suite *RedisClientTestSuite) TestCreateStreamGroup_Error() { + suite.mock.On("XGroupCreateMkStream", mock.Anything, "mystream", "mygroup", "0").Return(errors.New("error")) + + err := suite.client.CreateStreamGroup(context.Background(), "mystream", "mygroup", "0") + + suite.Error(err) + suite.mock.AssertCalled(suite.T(), "XGroupCreateMkStream", mock.Anything, "mystream", "mygroup", "0") +} + +func (suite *RedisClientTestSuite) TestDeleteStreamGroup_Success() { + suite.mock.On("XGroupDestroy", mock.Anything, "mystream", "mygroup").Return(nil) + + err := suite.client.DeleteStreamGroup(context.Background(), "mystream", "mygroup") + + suite.NoError(err) + suite.mock.AssertCalled(suite.T(), "XGroupDestroy", mock.Anything, "mystream", "mygroup") +} + +func (suite *RedisClientTestSuite) TestDeleteStreamGroup_Error() { + suite.mock.On("XGroupDestroy", mock.Anything, "mystream", "mygroup").Return(errors.New("error")) + + err := suite.client.DeleteStreamGroup(context.Background(), "mystream", "mygroup") + + suite.Error(err) + suite.mock.AssertCalled(suite.T(), "XGroupDestroy", mock.Anything, "mystream", "mygroup") +} + +func (suite *RedisClientTestSuite) TestDeleteKeys_Success() { + suite.mock.On("Del", mock.Anything, []string{"key1", "key2"}).Return(nil) + + err := suite.client.DeleteKeys(context.Background(), "key1", "key2") + + suite.NoError(err) + suite.mock.AssertCalled(suite.T(), "Del", mock.Anything, []string{"key1", "key2"}) +} + +func (suite *RedisClientTestSuite) TestDeleteKeys_Error() { + suite.mock.On("Del", mock.Anything, []string{"key1", "key2"}).Return(errors.New("error")) + + err := suite.client.DeleteKeys(context.Background(), "key1", "key2") + + suite.Error(err) + suite.mock.AssertCalled(suite.T(), "Del", mock.Anything, []string{"key1", "key2"}) +} + +func (suite *RedisClientTestSuite) TestStreamInfo_Success() { + info := &redis.XInfoStream{ + Length: 5, + } + suite.mock.On("XInfoStream", mock.Anything, "mystream").Return(nil, info) + + result, err := suite.client.StreamInfo(context.Background(), "mystream") + + suite.NoError(err) + suite.Equal(int64(5), result.Length) + suite.mock.AssertCalled(suite.T(), "XInfoStream", mock.Anything, "mystream") +} + +func (suite *RedisClientTestSuite) TestStreamInfo_Error() { + suite.mock.On("XInfoStream", mock.Anything, "mystream").Return(errors.New("error"), nil) + + result, err := suite.client.StreamInfo(context.Background(), "mystream") + + suite.Error(err) + suite.Nil(result) + suite.mock.AssertCalled(suite.T(), "XInfoStream", mock.Anything, "mystream") +} + +func (suite *RedisClientTestSuite) TestStreamGroupInfo_Success() { + groupInfo := []redis.XInfoGroup{ + {Name: "group1", Consumers: 2}, + } + suite.mock.On("XInfoGroups", mock.Anything, "mystream").Return(nil, groupInfo) + + result, err := suite.client.StreamGroupInfo(context.Background(), "mystream") + + suite.NoError(err) + suite.Equal("group1", result[0].Name) + suite.mock.AssertCalled(suite.T(), "XInfoGroups", mock.Anything, "mystream") +} + +func (suite *RedisClientTestSuite) TestStreamGroupInfo_Error() { + suite.mock.On("XInfoGroups", mock.Anything, "mystream").Return(errors.New("error"), nil) + + result, err := suite.client.StreamGroupInfo(context.Background(), "mystream") + + suite.Error(err) + suite.Nil(result) + suite.mock.AssertCalled(suite.T(), "XInfoGroups", mock.Anything, "mystream") +} + +func (suite *RedisClientTestSuite) TestIsStreamExists_True() { + suite.mock.On("XInfoStream", mock.Anything, "mystream").Return(nil, &redis.XInfoStream{}) + + exists := suite.client.IsStreamExists(context.Background(), "mystream") + + suite.True(exists) + suite.mock.AssertCalled(suite.T(), "XInfoStream", mock.Anything, "mystream") +} + +func (suite *RedisClientTestSuite) TestIsStreamExists_False() { + suite.mock.On("XInfoStream", mock.Anything, "mystream").Return(errors.New("error"), nil) + + exists := suite.client.IsStreamExists(context.Background(), "mystream") + + suite.False(exists) + suite.mock.AssertCalled(suite.T(), "XInfoStream", mock.Anything, "mystream") +} + +func (suite *RedisClientTestSuite) TestPendingMsgCount_Success() { + pending := &redis.XPending{ + Count: 10, + } + suite.mock.On("XPending", mock.Anything, "mystream", "mygroup").Return(nil, pending) + + count, err := suite.client.PendingMsgCount(context.Background(), "mystream", "mygroup") + + suite.NoError(err) + suite.Equal(int64(10), count) + suite.mock.AssertCalled(suite.T(), "XPending", mock.Anything, "mystream", "mygroup") +} + +func (suite *RedisClientTestSuite) TestPendingMsgCount_Error() { + suite.mock.On("XPending", mock.Anything, "mystream", "mygroup").Return(errors.New("error"), nil) + + count, err := suite.client.PendingMsgCount(context.Background(), "mystream", "mygroup") + + suite.Error(err) + suite.Equal(int64(0), count) + suite.mock.AssertCalled(suite.T(), "XPending", mock.Anything, "mystream", "mygroup") +} + +func (suite *RedisClientTestSuite) TestIsStreamGroupExists_True() { + groupInfo := []redis.XInfoGroup{ + {Name: "mygroup"}, + } + suite.mock.On("XInfoGroups", mock.Anything, "mystream").Return(nil, groupInfo) + + exists := suite.client.IsStreamGroupExists(context.Background(), "mystream", "mygroup") + + suite.True(exists) + suite.mock.AssertCalled(suite.T(), "XInfoGroups", mock.Anything, "mystream") +} + +func (suite *RedisClientTestSuite) TestIsStreamGroupExists_False_By_Error() { + suite.mock.On("XInfoGroups", mock.Anything, "mystream").Return(errors.New("error"), nil) + + exists := suite.client.IsStreamGroupExists(context.Background(), "mystream", "mygroup") + + suite.False(exists) + suite.mock.AssertCalled(suite.T(), "XInfoGroups", mock.Anything, "mystream") +} + +func (suite *RedisClientTestSuite) TestIsStreamGroupExists_False_By_Empty() { + groupInfo := []redis.XInfoGroup{} + suite.mock.On("XInfoGroups", mock.Anything, "mystream").Return(nil, groupInfo) + + exists := suite.client.IsStreamGroupExists(context.Background(), "mystream", "mygroup") + + suite.False(exists) + suite.mock.AssertCalled(suite.T(), "XInfoGroups", mock.Anything, "mystream") +} + +// Run the test suite +func TestRedisClientTestSuite(t *testing.T) { + suite.Run(t, new(RedisClientTestSuite)) +} diff --git a/pkg/shared/clients/redis/redis_reader_test.go b/pkg/shared/clients/redis/redis_reader_test.go new file mode 100644 index 0000000000..883c20a4f8 --- /dev/null +++ b/pkg/shared/clients/redis/redis_reader_test.go @@ -0,0 +1,324 @@ +package redis + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" + "go.uber.org/zap/zaptest/observer" + + "github.com/numaproj/numaflow/pkg/isb" +) + +// Mocking the dependencies + +func setup() (*RedisStreamsRead, *redis.Client, *observer.ObservedLogs) { + core, obs := observer.New(zap.DebugLevel) + logger := zap.New(core).Sugar() + + mockClient := redis.NewClient(&redis.Options{ + Addr: "redis:6379", + }) + + r := &RedisStreamsRead{ + Name: "testName", + Stream: "testStream", + Group: "testGroup", + Consumer: "testConsumer", + PartitionIdx: 1, + RedisClient: &RedisClient{ + Client: mockClient, + }, + Options: Options{ + ReadTimeOut: 5 * time.Second, + CheckBackLog: false, + }, + Log: logger, + Metrics: Metrics{ + ReadErrorsInc: func() {}, + ReadsAdd: func(int) {}, + AcksAdd: func(int) {}, + AckErrorsAdd: func(int) {}, + }, + XStreamToMessages: func(xstreams []redis.XStream, messages []*isb.ReadMessage, labels map[string]string) ([]*isb.ReadMessage, error) { + return messages, nil + }, + } + + return r, mockClient, obs +} + +func Test_GetName(t *testing.T) { + reader, _, _ := setup() + assert.Equal(t, "testName", reader.GetName()) +} + +func Test_GetPartitionIdx(t *testing.T) { + reader, _, _ := setup() + assert.Equal(t, int32(1), reader.GetPartitionIdx()) +} +func Test_GetStreamName(t *testing.T) { + reader, _, _ := setup() + assert.Equal(t, "testStream", reader.GetStreamName()) +} +func Test_GetGroupName(t *testing.T) { + reader, _, _ := setup() + assert.Equal(t, "testGroup", reader.GetGroupName()) +} + +func TestRedisStreamsRead_processReadError(t *testing.T) { + t.Run("context canceled error", func(t *testing.T) { + loggerCore, _ := observer.New(zap.DebugLevel) + logger := zap.New(loggerCore).Sugar() + + reader := &RedisStreamsRead{ + Options: Options{CheckBackLog: true}, + Log: logger, + XStreamToMessages: func(xstreams []redis.XStream, messages []*isb.ReadMessage, labels map[string]string) ([]*isb.ReadMessage, error) { + return messages, nil + }, + } + + // Create a context canceled error + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + msgs, err := reader.processReadError(nil, nil, ctx.Err()) + assert.NoError(t, err) + assert.Empty(t, msgs) + }) + + t.Run("redis nil error", func(t *testing.T) { + loggerCore, _ := observer.New(zap.DebugLevel) + logger := zap.New(loggerCore).Sugar() + + reader := &RedisStreamsRead{ + Options: Options{CheckBackLog: true}, + Log: logger, + XStreamToMessages: func(xstreams []redis.XStream, messages []*isb.ReadMessage, labels map[string]string) ([]*isb.ReadMessage, error) { + return messages, nil + }, + } + + msgs, err := reader.processReadError(nil, nil, redis.Nil) + assert.NoError(t, err) + assert.Empty(t, msgs) + }) + + t.Run("generic error with metrics increment", func(t *testing.T) { + loggerCore, _ := observer.New(zap.DebugLevel) + logger := zap.New(loggerCore).Sugar() + + readErrorsIncCalled := false + reader := &RedisStreamsRead{ + Options: Options{CheckBackLog: true}, + Log: logger, + Metrics: Metrics{ + ReadErrorsInc: func() { + readErrorsIncCalled = true + }, + }, + XStreamToMessages: func(xstreams []redis.XStream, messages []*isb.ReadMessage, labels map[string]string) ([]*isb.ReadMessage, error) { + return messages, errors.New("conversion error") + }, + } + + msgs, err := reader.processReadError(nil, nil, errors.New("some error")) + assert.Error(t, err) + assert.Equal(t, "XReadGroup failed, some error", err.Error()) + assert.Empty(t, msgs) + assert.True(t, readErrorsIncCalled) + }) + + t.Run("generic error without metrics increment", func(t *testing.T) { + loggerCore, _ := observer.New(zap.DebugLevel) + logger := zap.New(loggerCore).Sugar() + + reader := &RedisStreamsRead{ + Options: Options{CheckBackLog: true}, + Log: logger, + XStreamToMessages: func(xstreams []redis.XStream, messages []*isb.ReadMessage, labels map[string]string) ([]*isb.ReadMessage, error) { + return messages, errors.New("conversion error") + }, + } + + msgs, err := reader.processReadError(nil, nil, errors.New("some error")) + assert.Error(t, err) + assert.Equal(t, "XReadGroup failed, some error", err.Error()) + assert.Empty(t, msgs) + }) +} + +func TestRedisStreamsRead_NoAck(t *testing.T) { + reader, _, _ := setup() + + offsets := []isb.Offset{isb.NewSimpleStringPartitionOffset("0-0", 0)} + reader.NoAck(context.Background(), offsets) + // NoAck does nothing, so just ensure no panic +} + +// Mock implementation for isb.Offset +type SimpleOffset struct { + offset string + partitionIdx int32 +} + +func (o SimpleOffset) String() string { + return o.offset +} + +func (o SimpleOffset) Sequence() (int64, error) { + return 0, nil // Not implemented for our mock +} + +func (o SimpleOffset) AckIt() error { + return nil // Mock implementation +} + +func (o SimpleOffset) NoAck() error { + return nil // Mock implementation +} + +func (o SimpleOffset) PartitionIdx() int32 { + return o.partitionIdx +} + +func (m *MockRedisClient) XReadGroup(ctx context.Context, a *redis.XReadGroupArgs) *redis.XStreamSliceCmd { + args := m.Called(ctx, a) + cmd := redis.NewXStreamSliceCmd(ctx) + cmd.SetVal(args.Get(0).([]redis.XStream)) + cmd.SetErr(args.Error(1)) + return cmd +} + +func (m *MockRedisClient) XAck(ctx context.Context, stream, group string, ids ...string) *redis.IntCmd { + args := m.Called(ctx, stream, group, ids) + cmd := redis.NewIntCmd(ctx) + cmd.SetErr(args.Error(0)) + return cmd +} + +// Define custom metrics increments and add functions +type MockMetrics struct { + mock.Mock +} + +func (m *MockMetrics) ReadErrorsInc() { + m.Called() +} + +func (m *MockMetrics) ReadsAdd(count int) { + m.Called(count) +} + +func (m *MockMetrics) AcksAdd(count int) { + m.Called(count) +} + +func (m *MockMetrics) AckErrorsAdd(count int) { + m.Called(count) +} + +// Test suite for RedisStreamsRead +type RedisStreamsReadTestSuite struct { + suite.Suite + client *RedisStreamsRead + mock *MockRedisClient + metrics *MockMetrics +} + +func (suite *RedisStreamsReadTestSuite) SetupTest() { + suite.mock = new(MockRedisClient) + suite.metrics = new(MockMetrics) + logger := zaptest.NewLogger(suite.T()) + suite.client = &RedisStreamsRead{ + Name: "test", + Stream: "mystream", + Group: "mygroup", + Consumer: "consumer1", + PartitionIdx: 0, + RedisClient: &RedisClient{ + Client: suite.mock, + }, + Options: Options{ + CheckBackLog: true, + ReadTimeOut: 10 * time.Second, + }, + Log: logger.Sugar(), + Metrics: Metrics{ + ReadErrorsInc: suite.metrics.ReadErrorsInc, + ReadsAdd: suite.metrics.ReadsAdd, + AcksAdd: suite.metrics.AcksAdd, + AckErrorsAdd: suite.metrics.AckErrorsAdd, + }, + XStreamToMessages: func(xstreams []redis.XStream, messages []*isb.ReadMessage, labels map[string]string) ([]*isb.ReadMessage, error) { + return messages, nil + }, + } +} + +func (suite *RedisStreamsReadTestSuite) TestRead_Success_Backlog() { + xstreams := []redis.XStream{ + {Messages: []redis.XMessage{ + {ID: "1", Values: map[string]interface{}{"data": "value1"}}, + }}, + } + suite.mock.On("XReadGroup", mock.Anything, mock.Anything).Return(xstreams, nil) + suite.metrics.On("ReadsAdd", 1).Return() + + messages, err := suite.client.Read(context.Background(), 1) + + suite.NoError(err) + suite.Equal(0, len(messages)) + suite.metrics.AssertCalled(suite.T(), "ReadsAdd", 1) +} + +func (suite *RedisStreamsReadTestSuite) TestAck_Success() { + offsets := []isb.Offset{SimpleOffset{offset: "1"}} + dedupOffsets := []string{"1"} + suite.mock.On("XAck", mock.Anything, "mystream", "mygroup", dedupOffsets).Return(nil) + suite.metrics.On("AcksAdd", 1).Return() + + errs := suite.client.Ack(context.Background(), offsets) + + for _, err := range errs { + suite.NoError(err) + } + suite.metrics.AssertCalled(suite.T(), "AcksAdd", 1) +} + +func (suite *RedisStreamsReadTestSuite) TestAck_Error() { + offsets := []isb.Offset{SimpleOffset{offset: "1"}} + dedupOffsets := []string{"1"} + suite.mock.On("XAck", mock.Anything, "mystream", "mygroup", dedupOffsets).Return(errors.New("ack error")) + suite.metrics.On("AckErrorsAdd", 1).Return() + + errs := suite.client.Ack(context.Background(), offsets) + + for _, err := range errs { + suite.Error(err) + suite.Equal("ack error", err.Error()) + } + suite.metrics.AssertCalled(suite.T(), "AckErrorsAdd", 1) +} + +func (suite *RedisStreamsReadTestSuite) TestPending_Error() { + suite.mock.On("XInfoGroups", mock.Anything, "mystream").Return(nil, errors.New("info error")) + + lag, err := suite.client.Pending(context.Background()) + + suite.Error(err) + suite.Equal(isb.PendingNotAvailable, lag) +} + +// Run the test suite +func TestRedisStreamsReadTestSuite(t *testing.T) { + suite.Run(t, new(RedisStreamsReadTestSuite)) +} diff --git a/pkg/shared/idlehandler/idlehandler_test.go b/pkg/shared/idlehandler/idlehandler_test.go new file mode 100644 index 0000000000..8e50685648 --- /dev/null +++ b/pkg/shared/idlehandler/idlehandler_test.go @@ -0,0 +1,120 @@ +package idlehandler + +import ( + "context" + "testing" + + "github.com/stretchr/testify/mock" + "go.uber.org/zap/zaptest" + + dfv1 "github.com/numaproj/numaflow/pkg/apis/numaflow/v1alpha1" + "github.com/numaproj/numaflow/pkg/isb" + "github.com/numaproj/numaflow/pkg/watermark/wmb" +) + +// Define a mock BufferWriter +type MockBufferWriter struct { + mock.Mock +} + +func (m *MockBufferWriter) GetName() string { + args := m.Called() + return args.String(0) +} + +func (m *MockBufferWriter) GetPartitionIdx() int32 { + args := m.Called() + return args.Get(0).(int32) +} + +func (m *MockBufferWriter) Write(ctx context.Context, messages []isb.Message) ([]isb.Offset, []error) { + args := m.Called(ctx, messages) + return args.Get(0).([]isb.Offset), args.Get(1).([]error) +} + +func (m *MockBufferWriter) Close() error { + args := m.Called() + return args.Error(0) +} + +// Define a mock Publisher +type MockPublisher struct { + mock.Mock +} + +func (m *MockPublisher) PublishWatermark(wm wmb.Watermark, offset isb.Offset, partition int32) { + m.Called(wm, offset, partition) +} + +func (m *MockPublisher) PublishIdleWatermark(wm wmb.Watermark, offset isb.Offset, partition int32) { + m.Called(wm, offset, partition) +} + +func (m *MockPublisher) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockPublisher) GetLatestWatermark() wmb.Watermark { + args := m.Called() + return args.Get(0).(wmb.Watermark) +} + +// Define a mock IdleManager +type MockIdleManager struct { + mock.Mock +} + +func (m *MockIdleManager) MarkIdle(fromBufferPartitionIndex int32, toPartitionName string) { + m.Called(fromBufferPartitionIndex, toPartitionName) +} + +func (m *MockIdleManager) NeedToSendCtrlMsg(toPartitionName string) bool { + args := m.Called(toPartitionName) + return args.Bool(0) +} + +func (m *MockIdleManager) Update(fromBufferPartitionIndex int32, toPartitionName string, offset isb.Offset) { + m.Called(fromBufferPartitionIndex, toPartitionName, offset) +} + +func (m *MockIdleManager) Get(toPartitionName string) isb.Offset { + args := m.Called(toPartitionName) + return args.Get(0).(isb.Offset) +} + +func (m *MockIdleManager) MarkActive(fromBufferPartitionIndex int32, toPartitionName string) { + m.Called(fromBufferPartitionIndex, toPartitionName) +} + +// Unit test the PublishIdleWatermark function +func TestPublishIdleWatermark_SinkVertex(t *testing.T) { + ctx := context.Background() + fromBufferPartitionIndex := int32(1) + vertexName := "test-vertex" + pipelineName := "test-pipeline" + vertexType := dfv1.VertexTypeSink + vertexReplica := int32(0) + wm := wmb.Watermark{} + + logger := zaptest.NewLogger(t).Sugar() + + toBufferPartition := new(MockBufferWriter) + toBufferPartition.On("GetName").Return("test-partition") + toBufferPartition.On("GetPartitionIdx").Return(int32(0)) + + wmPublisher := new(MockPublisher) + idleManager := new(MockIdleManager) + + idleManager.On("MarkIdle", fromBufferPartitionIndex, "test-partition").Return() + idleManager.On("NeedToSendCtrlMsg", "test-partition").Return(false) + wmPublisher.On("PublishIdleWatermark", wm, nil, int32(0)).Return() + wmPublisher.On("GetLatestWatermark").Return(wm) + + PublishIdleWatermark(ctx, fromBufferPartitionIndex, toBufferPartition, wmPublisher, idleManager, logger, vertexName, pipelineName, vertexType, vertexReplica, wm) + + // Verify that the mock methods are called as expected + idleManager.AssertCalled(t, "MarkIdle", fromBufferPartitionIndex, "test-partition") + // idleManager.AssertNotCalled(t, "NeedToSendCtrlMsg", "test-partition") + wmPublisher.AssertCalled(t, "PublishIdleWatermark", wm, nil, int32(0)) +} diff --git a/pkg/shared/idlehandler/source_idlehandler_test.go b/pkg/shared/idlehandler/source_idlehandler_test.go index 82b29675b6..b34d8b5adf 100644 --- a/pkg/shared/idlehandler/source_idlehandler_test.go +++ b/pkg/shared/idlehandler/source_idlehandler_test.go @@ -21,6 +21,10 @@ import ( "time" dfv1 "github.com/numaproj/numaflow/pkg/apis/numaflow/v1alpha1" + "github.com/numaproj/numaflow/pkg/isb" + "github.com/numaproj/numaflow/pkg/watermark/wmb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -108,3 +112,44 @@ func TestSourceIdleHandler_IsSourceIdling(t *testing.T) { }) } } + +// Mock SourceFetcher and SourcePublisher for testing +type MockSourceFetcher struct { + mock.Mock +} + +func (m *MockSourceFetcher) ComputeWatermark() wmb.Watermark { + args := m.Called() + return args.Get(0).(wmb.Watermark) +} + +func (m *MockSourceFetcher) ComputeHeadWatermark(fromPartitionIdx int32) wmb.Watermark { + args := m.Called(fromPartitionIdx) + return args.Get(0).(wmb.Watermark) +} + +type MockSourcePublisher struct { + mock.Mock +} + +func (m *MockSourcePublisher) PublishIdleWatermarks(watermark time.Time, partitions []int32) { + m.Called(watermark, partitions) +} + +func (m *MockSourcePublisher) PublishSourceWatermarks(in []*isb.ReadMessage) { + m.Called(in) +} + +func TestSourceIdleHandler_Reset(t *testing.T) { + config := &dfv1.Watermark{} + mockFetcher := new(MockSourceFetcher) + mockPublisher := new(MockSourcePublisher) + + handler := NewSourceIdleHandler(config, mockFetcher, mockPublisher) + handler.lastIdleWmPublishedTime = time.Now() + + handler.Reset() + + assert.WithinDuration(t, time.Now(), handler.updatedTS, time.Second) + assert.Equal(t, time.UnixMilli(-1), handler.lastIdleWmPublishedTime) +} diff --git a/pkg/shared/kvs/jetstream/kv_store_test.go b/pkg/shared/kvs/jetstream/kv_store_test.go index c864fb4e34..1d696b0a36 100644 --- a/pkg/shared/kvs/jetstream/kv_store_test.go +++ b/pkg/shared/kvs/jetstream/kv_store_test.go @@ -26,6 +26,7 @@ import ( "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" + natsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test" "github.com/numaproj/numaflow/pkg/shared/kvs" ) @@ -39,7 +40,7 @@ func TestJetStreamKVStoreOperations(t *testing.T) { s := natstest.RunJetStreamServer(t) defer natstest.ShutdownJetStreamServer(t, s) - testClient := natstest.JetStreamClient(t, s) + testClient := natsclient.NewTestClientWithServer(t, s) defer testClient.Close() js, err := testClient.JetStreamContext() @@ -100,7 +101,7 @@ func TestJetStreamKVStoreWatch(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - testClient := natstest.JetStreamClient(t, s) + testClient := natsclient.NewTestClientWithServer(t, s) defer testClient.Close() js, err := testClient.JetStreamContext() @@ -123,7 +124,6 @@ func TestJetStreamKVStoreWatch(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - // write some key value entries inside a go routine go func() { defer wg.Done() // write 100 key value pairs @@ -132,7 +132,6 @@ func TestJetStreamKVStoreWatch(t *testing.T) { assert.NoError(t, err) } - // delete 50 key value pairs for i := 0; i < 50; i++ { err = kvStore.DeleteKey(ctx, fmt.Sprintf("key-%d", i)) } @@ -177,7 +176,7 @@ func TestJetStreamKVWithoutUpdates(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - testClient := natstest.JetStreamClient(t, s) + testClient := natsclient.NewTestClientWithServer(t, s) defer testClient.Close() js, err := testClient.JetStreamContext() @@ -245,3 +244,177 @@ watchLoop: wg.Wait() assert.Equal(t, 100, kvPutCount) } + +func TestJetStreamKVStoreErrorBinding(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + kvName := "errorKVStore" + + s := natstest.RunJetStreamServer(t) + defer natstest.ShutdownJetStreamServer(t, s) + + testClient := natsclient.NewTestClientWithServer(t, s) + defer testClient.Close() + + // Intentionally binding to a non-existing bucket to simulate error + _, err := NewKVJetStreamKVStore(ctx, kvName+"-non-existent", testClient) + assert.Error(t, err) +} + +func TestJetStreamKVStoreGetValueError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + kvName := "testJetStreamKVStore" + + s := natstest.RunJetStreamServer(t) + defer natstest.ShutdownJetStreamServer(t, s) + + testClient := natsclient.NewTestClientWithServer(t, s) + defer testClient.Close() + + js, err := testClient.JetStreamContext() + assert.NoError(t, err) + + _, err = js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: kvName, + }) + assert.NoError(t, err) + + defer func() { + err = js.DeleteKeyValue(kvName) + assert.NoError(t, err) + }() + + kvStore, err := NewKVJetStreamKVStore(ctx, kvName, testClient) + assert.NoError(t, err) + defer kvStore.Close() + + // Attempt to get a non-existent key to simulate error + _, err = kvStore.GetValue(ctx, "non-existent-key") + assert.Error(t, err) +} + +func TestJetStreamKVStoreListKeysError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + kvName := "testJetStreamKVStore" + + s := natstest.RunJetStreamServer(t) + defer natstest.ShutdownJetStreamServer(t, s) + + testClient := natsclient.NewTestClientWithServer(t, s) + defer testClient.Close() + + js, err := testClient.JetStreamContext() + assert.NoError(t, err) + + _, err = js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: kvName, + }) + assert.NoError(t, err) + + defer func() { + err = js.DeleteKeyValue(kvName) + assert.NoError(t, err) + }() + + kvStore, err := NewKVJetStreamKVStore(ctx, kvName, testClient) + assert.NoError(t, err) + defer kvStore.Close() + + // Attempt to list keys in a non-existent bucket to simulate error + _, err = NewKVJetStreamKVStore(ctx, "non-existent-bucket", testClient) + assert.Error(t, err) +} + +func TestJetStreamKVStorePutKVError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + kvName := "testJetStreamKVStore" + + s := natstest.RunJetStreamServer(t) + defer natstest.ShutdownJetStreamServer(t, s) + + testClient := natsclient.NewTestClientWithServer(t, s) + defer testClient.Close() + + js, err := testClient.JetStreamContext() + assert.NoError(t, err) + + _, err = js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: kvName, + }) + assert.NoError(t, err) + + defer func() { + err = js.DeleteKeyValue(kvName) + assert.NoError(t, err) + }() + + kvStore, err := NewKVJetStreamKVStore(ctx, kvName, testClient) + assert.NoError(t, err) + defer kvStore.Close() + + // Forcing a Put error by using a bad key name + err = kvStore.PutKV(ctx, "", []byte("value")) + assert.Error(t, err) +} + +func TestJetStreamKVStoreWatchError(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + kvName := "testJetStreamKVStore" + + s := natstest.RunJetStreamServer(t) + defer natstest.ShutdownJetStreamServer(t, s) + + testClient := natsclient.NewTestClientWithServer(t, s) + defer testClient.Close() + + js, err := testClient.JetStreamContext() + assert.NoError(t, err) + + _, err = js.CreateKeyValue(&nats.KeyValueConfig{ + Bucket: kvName, + }) + assert.NoError(t, err) + + defer func() { + err = js.DeleteKeyValue(kvName) + assert.NoError(t, err) + }() + + kvStore, err := NewKVJetStreamKVStore(ctx, kvName, testClient) + assert.NoError(t, err) + defer kvStore.Close() + + // Forcing a watch error by canceling the context immediately + ctx, cancel = context.WithCancel(context.Background()) + defer cancel() + cancel() + + kvCh := kvStore.Watch(ctx) + _, ok := <-kvCh + assert.False(t, ok) +} +func TestJetStreamKVStoreWithContextDone(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 0) + defer cancel() + + kvName := "testJetStreamKVStore" + + s := natstest.RunJetStreamServer(t) + defer natstest.ShutdownJetStreamServer(t, s) + + testClient := natsclient.NewTestClientWithServer(t, s) + defer testClient.Close() + + kvStore, err := NewKVJetStreamKVStore(ctx, kvName, testClient) + assert.Error(t, err) + assert.Nil(t, kvStore) +} diff --git a/pkg/shared/util/json_test.go b/pkg/shared/util/json_test.go index 76baef9d2f..b84420331d 100644 --- a/pkg/shared/util/json_test.go +++ b/pkg/shared/util/json_test.go @@ -24,10 +24,39 @@ import ( func TestMustJson(t *testing.T) { assert.Equal(t, "1", MustJSON(1)) + t.Run("error", func(t *testing.T) { + assert.Panics(t, func() { + type InvalidMarshal struct { + Channel chan int + } + // This should cause a panic because `json.Marshal` cannot handle channels + MustJSON(&InvalidMarshal{Channel: make(chan int)}) + }) + }) } func TestUnJSON(t *testing.T) { var in int MustUnJSON("1", &in) assert.Equal(t, 1, in) + + t.Run("invalid json", func(t *testing.T) { + assert.Panics(t, func() { + var in int + MustUnJSON("invalid json", &in) + }) + }) + + t.Run("invalid type for in", func(t *testing.T) { + assert.Panics(t, func() { + MustUnJSON("1", 1) + }) + }) + + t.Run("unsupported type for v", func(t *testing.T) { + assert.Panics(t, func() { + var in int + MustUnJSON(1, &in) + }) + }) } diff --git a/pkg/shared/util/kubeconfig_test.go b/pkg/shared/util/kubeconfig_test.go new file mode 100644 index 0000000000..a635a713d6 --- /dev/null +++ b/pkg/shared/util/kubeconfig_test.go @@ -0,0 +1,38 @@ +package util + +import ( + "os" + "path/filepath" + "testing" + + "k8s.io/client-go/util/homedir" + + "github.com/stretchr/testify/assert" +) + +func TestK8sRestConfig(t *testing.T) { + t.Run("K8sRestConfig returns error when KUBECONFIG is invalid", func(t *testing.T) { + // Setup the environment to simulate an invalid KUBECONFIG + kubeconfig := "invalid-kubeconfig" + os.Setenv("KUBECONFIG", kubeconfig) + defer os.Unsetenv("KUBECONFIG") + + config, err := K8sRestConfig() + assert.NotNil(t, err) + assert.Nil(t, config) + }) + +} + +func TestK8sRestConfig_blank(t *testing.T) { + os.Unsetenv("KUBECONFIG") + + // Ensure the default kubeconfig does not exist + homeDir := homedir.HomeDir() + defaultKubeconfigPath := filepath.Join(homeDir, ".kube", "config") + os.Remove(defaultKubeconfigPath) + + restConfig, err := K8sRestConfig() + assert.Error(t, err) + assert.Nil(t, restConfig) +} diff --git a/pkg/shared/util/sasl_config_test.go b/pkg/shared/util/sasl_config_test.go index 40ff32f474..7cbec8a876 100644 --- a/pkg/shared/util/sasl_config_test.go +++ b/pkg/shared/util/sasl_config_test.go @@ -26,6 +26,7 @@ import ( ) func TestSaslConfiguration(t *testing.T) { + mockedVolumes := mockedVolumes{ volumeSecrets: map[struct { objectName string @@ -99,4 +100,184 @@ func TestSaslConfiguration(t *testing.T) { assert.Equal(t, "user", config.User) assert.Equal(t, "password", config.Password) }) + + t.Run("gssapi", func(t *testing.T) { + plain := dfv1.SASLTypePlaintext + temp := dfv1.SASL{ + Mechanism: &plain, + Plain: credentials, + } + config, err := GetGSSAPIConfig(temp.GSSAPI) + assert.NoError(t, err) + assert.Nil(t, config) + + }) +} + +func TestGetGSSAPIConfig_NilConfig(t *testing.T) { + config, err := GetGSSAPIConfig(nil) + assert.NoError(t, err) + assert.Nil(t, config) +} + +func TestGetGSSAPIConfig_InvalidAuthType(t *testing.T) { + + var authType dfv1.KRB5AuthType = "anytpe" + + config := &dfv1.GSSAPI{ + ServiceName: "service", + Realm: "realm", + AuthType: &authType, + } + + _, err := GetGSSAPIConfig(config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse GSSAPI AuthType") +} + +func TestXDGSCRAMClient_Begin_SHA256(t *testing.T) { + client := &XDGSCRAMClient{HashGeneratorFcn: SHA256} + err := client.Begin("username", "password", "") + assert.NoError(t, err) + assert.NotNil(t, client.Client) + assert.NotNil(t, client.ClientConversation) +} + +func TestXDGSCRAMClient_Begin_SHA512(t *testing.T) { + client := &XDGSCRAMClient{HashGeneratorFcn: SHA512} + err := client.Begin("username", "password", "") + assert.NoError(t, err) + assert.NotNil(t, client.Client) + assert.NotNil(t, client.ClientConversation) +} + +func TestXDGSCRAMClient_Step(t *testing.T) { + client := &XDGSCRAMClient{HashGeneratorFcn: SHA256} + err := client.Begin("username", "password", "") + assert.NoError(t, err) + + response, err := client.Step("challenge") + assert.NoError(t, err) + assert.NotEmpty(t, response) +} + +func TestXDGSCRAMClient_Done(t *testing.T) { + client := &XDGSCRAMClient{HashGeneratorFcn: SHA256} + err := client.Begin("username", "password", "") + assert.NoError(t, err) + + _, err = client.Step("challenge") + assert.NoError(t, err) + assert.False(t, client.Done()) +} + +type mockGSSAPI struct { + ServiceName string + Realm string + UsernameSecret *corev1.SecretKeySelector + AuthType *dfv1.KRB5AuthType + PasswordSecret *corev1.SecretKeySelector + KeytabSecret *corev1.SecretKeySelector + KerberosConfigSecret *corev1.SecretKeySelector +} + +func TestGetGSSAPIConfig(t *testing.T) { + + authType := dfv1.KRB5UserAuth + tests := []struct { + name string + config *mockGSSAPI + want *sarama.GSSAPIConfig + wantErr bool + }{ + { + name: "invalid auth type", + config: &mockGSSAPI{ + ServiceName: "testService", + Realm: "testRealm", + AuthType: new(dfv1.KRB5AuthType), // invalid auth type + }, + want: nil, + wantErr: true, + }, + { + name: "error fetching username secret", + config: &mockGSSAPI{ + ServiceName: "testService", + Realm: "testRealm", + AuthType: &authType, + UsernameSecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "error"}, + Key: "username", + }, + }, + want: nil, + wantErr: true, + }, + { + name: "error fetching Kerbos config secret", + config: &mockGSSAPI{ + ServiceName: "testService", + Realm: "testRealm", + AuthType: &authType, + KeytabSecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "error"}, + Key: "keytab", + }, + }, + want: nil, + wantErr: true, + }, + + { + name: "error fetching keytab file", + config: &mockGSSAPI{ + ServiceName: "testService", + Realm: "testRealm", + AuthType: &authType, + KerberosConfigSecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "error"}, + Key: "KerberosConfig", + }, + }, + want: nil, + wantErr: true, + }, + + { + name: "error fetching password", + config: &mockGSSAPI{ + ServiceName: "testService", + Realm: "testRealm", + AuthType: &authType, + PasswordSecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "error"}, + Key: "PasswordS", + }, + }, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := GetGSSAPIConfig(&dfv1.GSSAPI{ + ServiceName: tt.config.ServiceName, + Realm: tt.config.Realm, + UsernameSecret: tt.config.UsernameSecret, + AuthType: tt.config.AuthType, + PasswordSecret: tt.config.PasswordSecret, + KeytabSecret: tt.config.KeytabSecret, + KerberosConfigSecret: tt.config.KerberosConfigSecret, + }) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + } + }) + } } diff --git a/pkg/shared/util/tls_config_test.go b/pkg/shared/util/tls_config_test.go new file mode 100644 index 0000000000..9e320eae74 --- /dev/null +++ b/pkg/shared/util/tls_config_test.go @@ -0,0 +1,74 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + corev1 "k8s.io/api/core/v1" + + dfv1 "github.com/numaproj/numaflow/pkg/apis/numaflow/v1alpha1" +) + +// Mock structure +type Secret struct { + Name string +} + +func TestGetTLSConfig_NilConfig(t *testing.T) { + config, err := GetTLSConfig(nil) + assert.NoError(t, err) + assert.Nil(t, config) +} + +func MockTLSObject() *dfv1.TLS { + return &dfv1.TLS{ + InsecureSkipVerify: true, + CACertSecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "ca-cert-secret"}, + Key: "caCert", + }, + CertSecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "cert-secret"}, + Key: "cert", + }, + KeySecret: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "key-secret"}, + Key: "key", + }, + } +} + +type MockSecretVolumePath struct { + mock.Mock +} + +func (m *MockSecretVolumePath) GetSecretVolumePath(secret *corev1.SecretKeySelector) (string, error) { + args := m.Called(secret) + return args.String(0), args.Error(1) +} + +func TestGetTLSConfig_CertWithoutKey(t *testing.T) { + mockVolumePath := new(MockSecretVolumePath) + mockTLS := MockTLSObject() + mockTLS.KeySecret = nil + + mockVolumePath.On("GetSecretVolumePath", mockTLS.CertSecret).Return("/mocked/path/cert-secret", nil) + + _, err := GetTLSConfig(mockTLS) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid tls config") +} + +func TestGetTLSConfig_KeyWithoutCert(t *testing.T) { + mockVolumePath := new(MockSecretVolumePath) + mockTLS := MockTLSObject() + mockTLS.CertSecret = nil + + mockVolumePath.On("GetSecretVolumePath", mockTLS.KeySecret).Return("/mocked/path/key-secret", nil) + + _, err := GetTLSConfig(mockTLS) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid tls config") + +} diff --git a/pkg/shared/util/volume_test.go b/pkg/shared/util/volume_test.go index eb62a3e497..52d58dff20 100644 --- a/pkg/shared/util/volume_test.go +++ b/pkg/shared/util/volume_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" corev1 "k8s.io/api/core/v1" ) @@ -152,3 +153,41 @@ func Test_GetSecretVolumePath(t *testing.T) { assert.Nil(t, e) assert.Equal(t, "/var/numaflow/secrets/test-secret/test-key", p) } + +type MockFileReader struct { + mock.Mock +} + +func (m *MockFileReader) getConfigMapFromVolume(selector *corev1.ConfigMapKeySelector) (string, error) { + args := m.Called(selector) + return args.String(0), args.Error(1) +} + +func (m *MockFileReader) getSecretFromVolume(selector *corev1.SecretKeySelector) (string, error) { + args := m.Called(selector) + return args.String(0), args.Error(1) +} + +func TestGetConfigMapFromVolume_FileNotFound(t *testing.T) { + // file not found error + selector := &corev1.ConfigMapKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "test-configmap"}, + Key: "test-key", + } + + _, err := GetConfigMapFromVolume(selector) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to get configMap value") +} + +func TestGetSecretFromVolume_FileNotFound(t *testing.T) { + + selector := &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{Name: "test-secret"}, + Key: "test-key", + } + + _, err := GetSecretFromVolume(selector) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to get secret value") +} diff --git a/pkg/sideinputs/initializer/initializer_test.go b/pkg/sideinputs/initializer/initializer_test.go index 03b64d15bc..2eb264c978 100644 --- a/pkg/sideinputs/initializer/initializer_test.go +++ b/pkg/sideinputs/initializer/initializer_test.go @@ -27,6 +27,7 @@ import ( "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" + natsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test" "github.com/numaproj/numaflow/pkg/shared/kvs/jetstream" "github.com/numaproj/numaflow/pkg/sideinputs/utils" @@ -58,7 +59,7 @@ func TestSideInputsInitializer_Success(t *testing.T) { defer cancel() // connect to NATS - nc := natstest.JetStreamClient(t, s) + nc := natsclient.NewTestClientWithServer(t, s) defer nc.Close() // create JetStream Context @@ -127,7 +128,7 @@ func TestSideInputsTimeout(t *testing.T) { defer cancel() // connect to NATS - nc := natstest.JetStreamClient(t, s) + nc := natsclient.NewTestClientWithServer(t, s) defer nc.Close() // create JetStream Context diff --git a/pkg/sideinputs/synchronizer/synchronizer_test.go b/pkg/sideinputs/synchronizer/synchronizer_test.go index 9338929ec3..fec8ea58e7 100644 --- a/pkg/sideinputs/synchronizer/synchronizer_test.go +++ b/pkg/sideinputs/synchronizer/synchronizer_test.go @@ -27,6 +27,7 @@ import ( "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" + natsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test" "github.com/numaproj/numaflow/pkg/shared/kvs/jetstream" "github.com/numaproj/numaflow/pkg/sideinputs/utils" @@ -65,7 +66,7 @@ func TestSideInputsValueUpdates(t *testing.T) { defer cancel() // connect to NATS - nc := natstest.JetStreamClient(t, s) + nc := natsclient.NewTestClientWithServer(t, s) defer nc.Close() // create JetStream Context diff --git a/pkg/watermark/fetch/edge_fetcher_test.go b/pkg/watermark/fetch/edge_fetcher_test.go index a55fdb2d5f..78f39b9ed9 100644 --- a/pkg/watermark/fetch/edge_fetcher_test.go +++ b/pkg/watermark/fetch/edge_fetcher_test.go @@ -28,6 +28,7 @@ import ( "go.uber.org/zap/zaptest" "github.com/numaproj/numaflow/pkg/isb" + natsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test" "github.com/numaproj/numaflow/pkg/shared/kvs" "github.com/numaproj/numaflow/pkg/shared/kvs/jetstream" @@ -1014,7 +1015,7 @@ func TestFetcherWithSameOTBucketWithSinglePartition(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() // connect to NATS - nc := natstest.JetStreamClient(t, s) + nc := natsclient.NewTestClientWithServer(t, s) defer nc.Close() // create JetStream Context @@ -1051,7 +1052,7 @@ func TestFetcherWithSameOTBucketWithSinglePartition(t *testing.T) { defer func() { _ = js.DeleteKeyValue(keyspace + "_OT") }() assert.NoError(t, err) - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() // create hbStore @@ -1311,7 +1312,7 @@ func TestFetcherWithSameOTBucketWithMultiplePartition(t *testing.T) { defer cancel() // connect to NATS - nc := natstest.JetStreamClient(t, s) + nc := natsclient.NewTestClientWithServer(t, s) defer nc.Close() // create JetStream Context @@ -1348,7 +1349,7 @@ func TestFetcherWithSameOTBucketWithMultiplePartition(t *testing.T) { defer func() { _ = js.DeleteKeyValue(keyspace + "_OT") }() assert.NoError(t, err) - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() // create hbStore diff --git a/pkg/watermark/publish/publisher_test.go b/pkg/watermark/publish/publisher_test.go index 7cd029206b..4a182c6b92 100644 --- a/pkg/watermark/publish/publisher_test.go +++ b/pkg/watermark/publish/publisher_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/numaproj/numaflow/pkg/isb" + natsclient "github.com/numaproj/numaflow/pkg/shared/clients/nats" natstest "github.com/numaproj/numaflow/pkg/shared/clients/nats/test" "github.com/numaproj/numaflow/pkg/watermark/entity" "github.com/numaproj/numaflow/pkg/watermark/store" @@ -49,7 +50,7 @@ func TestPublisherWithSharedOTBucket(t *testing.T) { var ctx = context.Background() - defaultJetStreamClient := natstest.JetStreamClient(t, s) + defaultJetStreamClient := natsclient.NewTestClientWithServer(t, s) defer defaultJetStreamClient.Close() js, err := defaultJetStreamClient.JetStreamContext() assert.NoError(t, err)