From 0473bdf0418b9983782744e66a140a0bc7046c2e Mon Sep 17 00:00:00 2001 From: Pedro Soares Date: Thu, 6 Feb 2025 22:08:18 -0300 Subject: [PATCH] fix(nats): drain connections upon error When failing to connect, drain connection rather than closing directly. This is safer than calling close and will wait for connections before closing, gracefully handling ongoing reconnects and avoid leaking goroutines in test --- .github/workflows/tests.yaml | 2 +- cluster/nats_rpc_client_test.go | 1 + cluster/nats_rpc_common.go | 48 ++++++- cluster/nats_rpc_common_test.go | 233 ++++++++++++++------------------ cluster/nats_rpc_server_test.go | 56 ++++++-- 5 files changed, 196 insertions(+), 144 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 2853736c..8010ba0a 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -43,7 +43,7 @@ jobs: - name: Send coverage env: COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: ~/go/bin/goveralls -coverprofile=coverprofile.out -service=github + run: go run goveralls -coverprofile=coverprofile.out -service=github e2e-test-nats: name: Nats Test End to End runs-on: ubuntu-latest diff --git a/cluster/nats_rpc_client_test.go b/cluster/nats_rpc_client_test.go index 19eedb8e..72317d20 100644 --- a/cluster/nats_rpc_client_test.go +++ b/cluster/nats_rpc_client_test.go @@ -441,6 +441,7 @@ func TestNatsRPCClientCall(t *testing.T) { t.Run(table.name, func(t *testing.T) { ctrl := gomock.NewController(t) conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) + defer conn.Close() assert.NoError(t, err) sv2 := getServer() diff --git a/cluster/nats_rpc_common.go b/cluster/nats_rpc_common.go index 17d9b591..fa79e474 100644 --- a/cluster/nats_rpc_common.go +++ b/cluster/nats_rpc_common.go @@ -34,13 +34,46 @@ func getChannel(serverType, serverID string) string { return fmt.Sprintf("pitaya/servers/%s/%s", serverType, serverID) } +func drainAndClose(nc *nats.Conn) error { + if nc == nil { + return nil + } + // Drain connection (this will flush any pending messages and prevent new ones) + err := nc.Drain() + if err != nil { + logger.Log.Warnf("error draining nats connection: %v", err) + // Even if drain fails, try to close + nc.Close() + return err + } + + // Wait for drain to complete with timeout + timeout := time.After(5 * time.Second) + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for nc.IsDraining() { + select { + case <-ticker.C: + continue + case <-timeout: + logger.Log.Warn("drain timeout exceeded, forcing close") + nc.Close() + return fmt.Errorf("drain timeout exceeded") + } + } + + // Close will happen automatically after drain completes + return nil +} + func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) { connectedCh := make(chan bool) initialConnectErrorCh := make(chan error) natsOptions := append( options, - nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { - logger.Log.Warnf("disconnected from nats! Reason: %q\n", err) + nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { + logger.Log.Warnf("disconnected from nats (%s)! Reason: %q\n", nc.ConnectedAddr(), err) }), nats.ReconnectHandler(func(nc *nats.Conn) { logger.Log.Warnf("reconnected to nats server %s with address %s in cluster %s!", nc.ConnectedServerName(), nc.ConnectedAddr(), nc.ConnectedClusterName()) @@ -78,7 +111,8 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O logger.Log.Errorf(err.Error()) } }), - nats.ConnectHandler(func(*nats.Conn) { + nats.ConnectHandler(func(nc *nats.Conn) { + logger.Log.Infof("connected to nats on %s", nc.ConnectedAddr()) connectedCh <- true }), ) @@ -104,8 +138,16 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O case <-connectedCh: return nc, nil case err := <-initialConnectErrorCh: + drainErr := drainAndClose(nc) + if drainErr != nil { + logger.Log.Warnf("failed to drain and close: %s", drainErr) + } return nil, err case <-time.After(maxConnTimeout * 2): + drainErr := drainAndClose(nc) + if drainErr != nil { + logger.Log.Warnf("failed to drain and close: %s", drainErr) + } return nil, fmt.Errorf("timeout setting up nats connection") } } diff --git a/cluster/nats_rpc_common_test.go b/cluster/nats_rpc_common_test.go index de4aac71..2dca0f46 100644 --- a/cluster/nats_rpc_common_test.go +++ b/cluster/nats_rpc_common_test.go @@ -25,7 +25,6 @@ import ( "testing" "time" - "github.com/nats-io/nats-server/v2/test" nats "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "github.com/topfreegames/pitaya/v2/helpers" @@ -47,8 +46,13 @@ func TestNatsRPCCommonGetChannel(t *testing.T) { func TestNatsRPCCommonSetupNatsConn(t *testing.T) { t.Parallel() + var conn *nats.Conn s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + drainAndClose(conn) + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) assert.NotNil(t, conn) @@ -56,158 +60,127 @@ func TestNatsRPCCommonSetupNatsConn(t *testing.T) { func TestNatsRPCCommonSetupNatsConnShouldError(t *testing.T) { t.Parallel() - conn, err := setupNatsConn("nats://localhost:1234", nil) + conn, err := setupNatsConn("nats://invalid:1234", nil) assert.Error(t, err) assert.Nil(t, conn) } func TestNatsRPCCommonCloseHandler(t *testing.T) { t.Parallel() + var conn *nats.Conn s := helpers.GetTestNatsServer(t) + defer func() { + drainAndClose(conn) + s.Shutdown() + s.WaitForShutdown() + }() dieChan := make(chan bool) + go func() { + value, ok := <-dieChan + assert.True(t, ok) + assert.True(t, value) + }() + conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), dieChan, nats.MaxReconnects(1), nats.ReconnectWait(1*time.Millisecond)) assert.NoError(t, err) assert.NotNil(t, conn) +} - s.Shutdown() - - value, ok := <-dieChan - assert.True(t, ok) - assert.True(t, value) +func TestNatsRPCCommonWaitReconnections(t *testing.T) { + t.Parallel() + var conn *nats.Conn + ts := helpers.GetTestNatsServer(t) + defer func() { + drainAndClose(conn) + ts.Shutdown() + ts.WaitForShutdown() + }() + + invalidAddr := "nats://invalid:4222" + validAddr := ts.ClientURL() + + urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr) + + // Setup connection with retry enabled + appDieCh := make(chan bool) + conn, err := setupNatsConn( + urls, + appDieCh, + nats.ReconnectWait(10*time.Millisecond), + nats.MaxReconnects(5), + nats.RetryOnFailedConnect(true), + ) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.True(t, conn.IsConnected()) } -func TestSetupNatsConnReconnection(t *testing.T) { - t.Run("waits for reconnection on initial failure", func(t *testing.T) { - // Use an invalid address first to force initial connection failure - invalidAddr := "nats://invalid:4222" - validAddr := "nats://localhost:4222" +func TestNatsRPCCommonDoNotBlockOnConnectionFail(t *testing.T) { + t.Parallel() + invalidAddr := "nats://invalid:4222" - urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr) + appDieCh := make(chan bool) + done := make(chan any) - go func() { - time.Sleep(50 * time.Millisecond) - ts := test.RunDefaultServer() - defer ts.Shutdown() - <-time.After(200 * time.Millisecond) - }() + var conn *nats.Conn + ts := helpers.GetTestNatsServer(t) + defer func() { + drainAndClose(conn) + ts.Shutdown() + ts.WaitForShutdown() + }() - // Setup connection with retry enabled - appDieCh := make(chan bool) + go func() { conn, err := setupNatsConn( - urls, + invalidAddr, appDieCh, nats.ReconnectWait(10*time.Millisecond), - nats.MaxReconnects(5), + nats.MaxReconnects(2), nats.RetryOnFailedConnect(true), ) + assert.Error(t, err) + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-appDieCh: + case <-done: + case <-time.After(250 * time.Millisecond): + t.Fail() + } +} - assert.NoError(t, err) - assert.NotNil(t, conn) - assert.True(t, conn.IsConnected()) - - conn.Close() - }) - - t.Run("does not block indefinitely if all connect attempts fail", func(t *testing.T) { - invalidAddr := "nats://invalid:4222" - - appDieCh := make(chan bool) - done := make(chan any) - - ts := test.RunDefaultServer() - defer ts.Shutdown() - - go func() { - conn, err := setupNatsConn( - invalidAddr, - appDieCh, - nats.ReconnectWait(10*time.Millisecond), - nats.MaxReconnects(2), - nats.RetryOnFailedConnect(true), - ) - assert.Error(t, err) - assert.Nil(t, conn) - close(done) - close(appDieCh) - }() - - select { - case <-appDieCh: - case <-done: - case <-time.After(250 * time.Millisecond): - t.Fail() - } - }) - - t.Run("if it fails to connect, exit with error even if appDieChan is not ready to listen", func(t *testing.T) { - invalidAddr := "nats://invalid:4222" - - appDieCh := make(chan bool) - done := make(chan any) - - ts := test.RunDefaultServer() - defer ts.Shutdown() - - go func() { - conn, err := setupNatsConn(invalidAddr, appDieCh) - assert.Error(t, err) - assert.Nil(t, conn) - close(done) - close(appDieCh) - }() - - select { - case <-done: - case <-time.After(50 * time.Millisecond): - t.Fail() - } - }) - - t.Run("if connection takes too long, exit with error after waiting maxReconnTimeout", func(t *testing.T) { - invalidAddr := "nats://invalid:4222" - - appDieCh := make(chan bool) - done := make(chan any) - - initialConnectionTimeout := time.Nanosecond - maxReconnectionAtetmpts := 1 - reconnectWait := time.Nanosecond - reconnectJitter := time.Nanosecond - maxReconnectionTimeout := reconnectWait + reconnectJitter + initialConnectionTimeout - maxReconnTimeout := initialConnectionTimeout + (time.Duration(maxReconnectionAtetmpts) * maxReconnectionTimeout) - - maxTestTimeout := 100 * time.Millisecond - - // Assert that if it fails because of connection timeout the test will capture - assert.Greater(t, maxTestTimeout, maxReconnTimeout) - - ts := test.RunDefaultServer() - defer ts.Shutdown() - - go func() { - conn, err := setupNatsConn( - invalidAddr, - appDieCh, - nats.Timeout(initialConnectionTimeout), - nats.ReconnectWait(reconnectWait), - nats.MaxReconnects(maxReconnectionAtetmpts), - nats.ReconnectJitter(reconnectJitter, reconnectJitter), - nats.RetryOnFailedConnect(true), - ) - assert.Error(t, err) - assert.ErrorContains(t, err, "timeout setting up nats connection") - assert.Nil(t, conn) - close(done) - close(appDieCh) - }() - - select { - case <-done: - case <-time.After(maxTestTimeout): - t.Fail() - } - }) +func TestNatsRPCCommonFailWithoutAppDieChan(t *testing.T) { + t.Parallel() + invalidAddr := "nats://invalid:4222" + + appDieCh := make(chan bool) + done := make(chan any) + + var conn *nats.Conn + ts := helpers.GetTestNatsServer(t) + defer func() { + drainAndClose(conn) + ts.Shutdown() + ts.WaitForShutdown() + }() + + go func() { + conn, err := setupNatsConn(invalidAddr, appDieCh) + assert.Error(t, err) + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fail() + } } diff --git a/cluster/nats_rpc_server_test.go b/cluster/nats_rpc_server_test.go index 2661ada6..1e103e99 100644 --- a/cluster/nats_rpc_server_test.go +++ b/cluster/nats_rpc_server_test.go @@ -156,13 +156,17 @@ func TestNatsRPCServerOnSessionBind(t *testing.T) { rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn err = rpcServer.onSessionBind(context.Background(), mockSession) assert.NoError(t, err) assert.NotNil(t, rpcServer.userKickCh) + conn.Close() } func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) { @@ -171,7 +175,10 @@ func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) { sv := getServer() rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -181,6 +188,7 @@ func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) { conn.Publish(GetBindBroadcastTopic(sv.Type), dt) msg := helpers.ShouldEventuallyReceive(t, rpcServer.GetBindingsChannel()).(*nats.Msg) assert.Equal(t, msg.Data, dt) + conn.Close() } func TestNatsRPCServerSubscribeUserKickChannel(t *testing.T) { @@ -189,7 +197,10 @@ func TestNatsRPCServerSubscribeUserKickChannel(t *testing.T) { sv := getServer() rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -203,6 +214,7 @@ func TestNatsRPCServerSubscribeUserKickChannel(t *testing.T) { assert.NoError(t, err) msg := helpers.ShouldEventuallyReceive(t, rpcServer.getUserKickChannel()).(*protos.KickMsg) assert.Equal(t, msg.UserId, kick.UserId) + conn.Close() } func TestNatsRPCServerGetUserPushChannel(t *testing.T) { @@ -228,7 +240,10 @@ func TestNatsRPCServerSubscribeToUserMessages(t *testing.T) { sv := getServer() rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -251,6 +266,7 @@ func TestNatsRPCServerSubscribeToUserMessages(t *testing.T) { helpers.ShouldEventuallyReceive(t, rpcServer.userPushCh) }) } + conn.Close() } func TestNatsRPCServerSubscribe(t *testing.T) { @@ -258,7 +274,10 @@ func TestNatsRPCServerSubscribe(t *testing.T) { sv := getServer() rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -281,6 +300,7 @@ func TestNatsRPCServerSubscribe(t *testing.T) { assert.Equal(t, table.msg, r.Data) }) } + conn.Close() } func TestNatsRPCServerHandleMessages(t *testing.T) { @@ -293,7 +313,10 @@ func TestNatsRPCServerHandleMessages(t *testing.T) { rpcServer, _ := NewNatsRPCServer(cfg, sv, mockMetricsReporters, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -324,6 +347,7 @@ func TestNatsRPCServerHandleMessages(t *testing.T) { assert.Equal(t, table.req.Msg.Id, r.Msg.Id) }) } + conn.Close() } func TestNatsRPCServerInitShouldFailIfConnFails(t *testing.T) { @@ -343,7 +367,10 @@ func TestNatsRPCServerInitShouldFailIfConnFails(t *testing.T) { func TestNatsRPCServerInit(t *testing.T) { s := helpers.GetTestNatsServer(t) ctrl := gomock.NewController(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() cfg := config.NewDefaultPitayaConfig().Cluster.RPC.Server.Nats cfg.Connect = fmt.Sprintf("nats://%s", s.Addr()) sv := getServer() @@ -381,7 +408,10 @@ func TestNatsRPCServerInit(t *testing.T) { func TestNatsRPCServerProcessBindings(t *testing.T) { ctrl := gomock.NewController(t) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() cfg := config.NewDefaultPitayaConfig().Cluster.RPC.Server.Nats cfg.Connect = fmt.Sprintf("nats://%s", s.Addr()) sv := getServer() @@ -424,7 +454,10 @@ func TestNatsRPCServerProcessBindings(t *testing.T) { func TestNatsRPCServerProcessPushes(t *testing.T) { s := helpers.GetTestNatsServer(t) ctrl := gomock.NewController(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() cfg := config.NewDefaultPitayaConfig().Cluster.RPC.Server.Nats cfg.Connect = fmt.Sprintf("nats://%s", s.Addr()) sv := getServer() @@ -459,7 +492,10 @@ func TestNatsRPCServerProcessPushes(t *testing.T) { func TestNatsRPCServerProcessKick(t *testing.T) { s := helpers.GetTestNatsServer(t) ctrl := gomock.NewController(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() cfg := config.NewDefaultPitayaConfig().Cluster.RPC.Server.Nats cfg.Connect = fmt.Sprintf("nats://%s", s.Addr()) sv := getServer()