diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 2853736c..8010ba0a 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -43,7 +43,7 @@ jobs: - name: Send coverage env: COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: ~/go/bin/goveralls -coverprofile=coverprofile.out -service=github + run: go run goveralls -coverprofile=coverprofile.out -service=github e2e-test-nats: name: Nats Test End to End runs-on: ubuntu-latest diff --git a/cluster/nats_rpc_client_test.go b/cluster/nats_rpc_client_test.go index 19eedb8e..72317d20 100644 --- a/cluster/nats_rpc_client_test.go +++ b/cluster/nats_rpc_client_test.go @@ -441,6 +441,7 @@ func TestNatsRPCClientCall(t *testing.T) { t.Run(table.name, func(t *testing.T) { ctrl := gomock.NewController(t) conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) + defer conn.Close() assert.NoError(t, err) sv2 := getServer() diff --git a/cluster/nats_rpc_common.go b/cluster/nats_rpc_common.go index 17d9b591..fa79e474 100644 --- a/cluster/nats_rpc_common.go +++ b/cluster/nats_rpc_common.go @@ -34,13 +34,46 @@ func getChannel(serverType, serverID string) string { return fmt.Sprintf("pitaya/servers/%s/%s", serverType, serverID) } +func drainAndClose(nc *nats.Conn) error { + if nc == nil { + return nil + } + // Drain connection (this will flush any pending messages and prevent new ones) + err := nc.Drain() + if err != nil { + logger.Log.Warnf("error draining nats connection: %v", err) + // Even if drain fails, try to close + nc.Close() + return err + } + + // Wait for drain to complete with timeout + timeout := time.After(5 * time.Second) + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for nc.IsDraining() { + select { + case <-ticker.C: + continue + case <-timeout: + logger.Log.Warn("drain timeout exceeded, forcing close") + nc.Close() + return fmt.Errorf("drain timeout exceeded") + } + } + + // Close will happen automatically after drain completes + return nil +} + func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) { connectedCh := make(chan bool) initialConnectErrorCh := make(chan error) natsOptions := append( options, - nats.DisconnectErrHandler(func(_ *nats.Conn, err error) { - logger.Log.Warnf("disconnected from nats! Reason: %q\n", err) + nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { + logger.Log.Warnf("disconnected from nats (%s)! Reason: %q\n", nc.ConnectedAddr(), err) }), nats.ReconnectHandler(func(nc *nats.Conn) { logger.Log.Warnf("reconnected to nats server %s with address %s in cluster %s!", nc.ConnectedServerName(), nc.ConnectedAddr(), nc.ConnectedClusterName()) @@ -78,7 +111,8 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O logger.Log.Errorf(err.Error()) } }), - nats.ConnectHandler(func(*nats.Conn) { + nats.ConnectHandler(func(nc *nats.Conn) { + logger.Log.Infof("connected to nats on %s", nc.ConnectedAddr()) connectedCh <- true }), ) @@ -104,8 +138,16 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O case <-connectedCh: return nc, nil case err := <-initialConnectErrorCh: + drainErr := drainAndClose(nc) + if drainErr != nil { + logger.Log.Warnf("failed to drain and close: %s", drainErr) + } return nil, err case <-time.After(maxConnTimeout * 2): + drainErr := drainAndClose(nc) + if drainErr != nil { + logger.Log.Warnf("failed to drain and close: %s", drainErr) + } return nil, fmt.Errorf("timeout setting up nats connection") } } diff --git a/cluster/nats_rpc_common_test.go b/cluster/nats_rpc_common_test.go index de4aac71..2dca0f46 100644 --- a/cluster/nats_rpc_common_test.go +++ b/cluster/nats_rpc_common_test.go @@ -25,7 +25,6 @@ import ( "testing" "time" - "github.com/nats-io/nats-server/v2/test" nats "github.com/nats-io/nats.go" "github.com/stretchr/testify/assert" "github.com/topfreegames/pitaya/v2/helpers" @@ -47,8 +46,13 @@ func TestNatsRPCCommonGetChannel(t *testing.T) { func TestNatsRPCCommonSetupNatsConn(t *testing.T) { t.Parallel() + var conn *nats.Conn s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + drainAndClose(conn) + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) assert.NotNil(t, conn) @@ -56,158 +60,127 @@ func TestNatsRPCCommonSetupNatsConn(t *testing.T) { func TestNatsRPCCommonSetupNatsConnShouldError(t *testing.T) { t.Parallel() - conn, err := setupNatsConn("nats://localhost:1234", nil) + conn, err := setupNatsConn("nats://invalid:1234", nil) assert.Error(t, err) assert.Nil(t, conn) } func TestNatsRPCCommonCloseHandler(t *testing.T) { t.Parallel() + var conn *nats.Conn s := helpers.GetTestNatsServer(t) + defer func() { + drainAndClose(conn) + s.Shutdown() + s.WaitForShutdown() + }() dieChan := make(chan bool) + go func() { + value, ok := <-dieChan + assert.True(t, ok) + assert.True(t, value) + }() + conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), dieChan, nats.MaxReconnects(1), nats.ReconnectWait(1*time.Millisecond)) assert.NoError(t, err) assert.NotNil(t, conn) +} - s.Shutdown() - - value, ok := <-dieChan - assert.True(t, ok) - assert.True(t, value) +func TestNatsRPCCommonWaitReconnections(t *testing.T) { + t.Parallel() + var conn *nats.Conn + ts := helpers.GetTestNatsServer(t) + defer func() { + drainAndClose(conn) + ts.Shutdown() + ts.WaitForShutdown() + }() + + invalidAddr := "nats://invalid:4222" + validAddr := ts.ClientURL() + + urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr) + + // Setup connection with retry enabled + appDieCh := make(chan bool) + conn, err := setupNatsConn( + urls, + appDieCh, + nats.ReconnectWait(10*time.Millisecond), + nats.MaxReconnects(5), + nats.RetryOnFailedConnect(true), + ) + assert.NoError(t, err) + assert.NotNil(t, conn) + assert.True(t, conn.IsConnected()) } -func TestSetupNatsConnReconnection(t *testing.T) { - t.Run("waits for reconnection on initial failure", func(t *testing.T) { - // Use an invalid address first to force initial connection failure - invalidAddr := "nats://invalid:4222" - validAddr := "nats://localhost:4222" +func TestNatsRPCCommonDoNotBlockOnConnectionFail(t *testing.T) { + t.Parallel() + invalidAddr := "nats://invalid:4222" - urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr) + appDieCh := make(chan bool) + done := make(chan any) - go func() { - time.Sleep(50 * time.Millisecond) - ts := test.RunDefaultServer() - defer ts.Shutdown() - <-time.After(200 * time.Millisecond) - }() + var conn *nats.Conn + ts := helpers.GetTestNatsServer(t) + defer func() { + drainAndClose(conn) + ts.Shutdown() + ts.WaitForShutdown() + }() - // Setup connection with retry enabled - appDieCh := make(chan bool) + go func() { conn, err := setupNatsConn( - urls, + invalidAddr, appDieCh, nats.ReconnectWait(10*time.Millisecond), - nats.MaxReconnects(5), + nats.MaxReconnects(2), nats.RetryOnFailedConnect(true), ) + assert.Error(t, err) + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-appDieCh: + case <-done: + case <-time.After(250 * time.Millisecond): + t.Fail() + } +} - assert.NoError(t, err) - assert.NotNil(t, conn) - assert.True(t, conn.IsConnected()) - - conn.Close() - }) - - t.Run("does not block indefinitely if all connect attempts fail", func(t *testing.T) { - invalidAddr := "nats://invalid:4222" - - appDieCh := make(chan bool) - done := make(chan any) - - ts := test.RunDefaultServer() - defer ts.Shutdown() - - go func() { - conn, err := setupNatsConn( - invalidAddr, - appDieCh, - nats.ReconnectWait(10*time.Millisecond), - nats.MaxReconnects(2), - nats.RetryOnFailedConnect(true), - ) - assert.Error(t, err) - assert.Nil(t, conn) - close(done) - close(appDieCh) - }() - - select { - case <-appDieCh: - case <-done: - case <-time.After(250 * time.Millisecond): - t.Fail() - } - }) - - t.Run("if it fails to connect, exit with error even if appDieChan is not ready to listen", func(t *testing.T) { - invalidAddr := "nats://invalid:4222" - - appDieCh := make(chan bool) - done := make(chan any) - - ts := test.RunDefaultServer() - defer ts.Shutdown() - - go func() { - conn, err := setupNatsConn(invalidAddr, appDieCh) - assert.Error(t, err) - assert.Nil(t, conn) - close(done) - close(appDieCh) - }() - - select { - case <-done: - case <-time.After(50 * time.Millisecond): - t.Fail() - } - }) - - t.Run("if connection takes too long, exit with error after waiting maxReconnTimeout", func(t *testing.T) { - invalidAddr := "nats://invalid:4222" - - appDieCh := make(chan bool) - done := make(chan any) - - initialConnectionTimeout := time.Nanosecond - maxReconnectionAtetmpts := 1 - reconnectWait := time.Nanosecond - reconnectJitter := time.Nanosecond - maxReconnectionTimeout := reconnectWait + reconnectJitter + initialConnectionTimeout - maxReconnTimeout := initialConnectionTimeout + (time.Duration(maxReconnectionAtetmpts) * maxReconnectionTimeout) - - maxTestTimeout := 100 * time.Millisecond - - // Assert that if it fails because of connection timeout the test will capture - assert.Greater(t, maxTestTimeout, maxReconnTimeout) - - ts := test.RunDefaultServer() - defer ts.Shutdown() - - go func() { - conn, err := setupNatsConn( - invalidAddr, - appDieCh, - nats.Timeout(initialConnectionTimeout), - nats.ReconnectWait(reconnectWait), - nats.MaxReconnects(maxReconnectionAtetmpts), - nats.ReconnectJitter(reconnectJitter, reconnectJitter), - nats.RetryOnFailedConnect(true), - ) - assert.Error(t, err) - assert.ErrorContains(t, err, "timeout setting up nats connection") - assert.Nil(t, conn) - close(done) - close(appDieCh) - }() - - select { - case <-done: - case <-time.After(maxTestTimeout): - t.Fail() - } - }) +func TestNatsRPCCommonFailWithoutAppDieChan(t *testing.T) { + t.Parallel() + invalidAddr := "nats://invalid:4222" + + appDieCh := make(chan bool) + done := make(chan any) + + var conn *nats.Conn + ts := helpers.GetTestNatsServer(t) + defer func() { + drainAndClose(conn) + ts.Shutdown() + ts.WaitForShutdown() + }() + + go func() { + conn, err := setupNatsConn(invalidAddr, appDieCh) + assert.Error(t, err) + assert.Nil(t, conn) + close(done) + close(appDieCh) + }() + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fail() + } } diff --git a/cluster/nats_rpc_server_test.go b/cluster/nats_rpc_server_test.go index 2661ada6..1e103e99 100644 --- a/cluster/nats_rpc_server_test.go +++ b/cluster/nats_rpc_server_test.go @@ -156,13 +156,17 @@ func TestNatsRPCServerOnSessionBind(t *testing.T) { rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn err = rpcServer.onSessionBind(context.Background(), mockSession) assert.NoError(t, err) assert.NotNil(t, rpcServer.userKickCh) + conn.Close() } func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) { @@ -171,7 +175,10 @@ func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) { sv := getServer() rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -181,6 +188,7 @@ func TestNatsRPCServerSubscribeToBindingsChannel(t *testing.T) { conn.Publish(GetBindBroadcastTopic(sv.Type), dt) msg := helpers.ShouldEventuallyReceive(t, rpcServer.GetBindingsChannel()).(*nats.Msg) assert.Equal(t, msg.Data, dt) + conn.Close() } func TestNatsRPCServerSubscribeUserKickChannel(t *testing.T) { @@ -189,7 +197,10 @@ func TestNatsRPCServerSubscribeUserKickChannel(t *testing.T) { sv := getServer() rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -203,6 +214,7 @@ func TestNatsRPCServerSubscribeUserKickChannel(t *testing.T) { assert.NoError(t, err) msg := helpers.ShouldEventuallyReceive(t, rpcServer.getUserKickChannel()).(*protos.KickMsg) assert.Equal(t, msg.UserId, kick.UserId) + conn.Close() } func TestNatsRPCServerGetUserPushChannel(t *testing.T) { @@ -228,7 +240,10 @@ func TestNatsRPCServerSubscribeToUserMessages(t *testing.T) { sv := getServer() rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -251,6 +266,7 @@ func TestNatsRPCServerSubscribeToUserMessages(t *testing.T) { helpers.ShouldEventuallyReceive(t, rpcServer.userPushCh) }) } + conn.Close() } func TestNatsRPCServerSubscribe(t *testing.T) { @@ -258,7 +274,10 @@ func TestNatsRPCServerSubscribe(t *testing.T) { sv := getServer() rpcServer, _ := NewNatsRPCServer(cfg, sv, nil, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -281,6 +300,7 @@ func TestNatsRPCServerSubscribe(t *testing.T) { assert.Equal(t, table.msg, r.Data) }) } + conn.Close() } func TestNatsRPCServerHandleMessages(t *testing.T) { @@ -293,7 +313,10 @@ func TestNatsRPCServerHandleMessages(t *testing.T) { rpcServer, _ := NewNatsRPCServer(cfg, sv, mockMetricsReporters, nil, nil) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() conn, err := setupNatsConn(fmt.Sprintf("nats://%s", s.Addr()), nil) assert.NoError(t, err) rpcServer.conn = conn @@ -324,6 +347,7 @@ func TestNatsRPCServerHandleMessages(t *testing.T) { assert.Equal(t, table.req.Msg.Id, r.Msg.Id) }) } + conn.Close() } func TestNatsRPCServerInitShouldFailIfConnFails(t *testing.T) { @@ -343,7 +367,10 @@ func TestNatsRPCServerInitShouldFailIfConnFails(t *testing.T) { func TestNatsRPCServerInit(t *testing.T) { s := helpers.GetTestNatsServer(t) ctrl := gomock.NewController(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() cfg := config.NewDefaultPitayaConfig().Cluster.RPC.Server.Nats cfg.Connect = fmt.Sprintf("nats://%s", s.Addr()) sv := getServer() @@ -381,7 +408,10 @@ func TestNatsRPCServerInit(t *testing.T) { func TestNatsRPCServerProcessBindings(t *testing.T) { ctrl := gomock.NewController(t) s := helpers.GetTestNatsServer(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() cfg := config.NewDefaultPitayaConfig().Cluster.RPC.Server.Nats cfg.Connect = fmt.Sprintf("nats://%s", s.Addr()) sv := getServer() @@ -424,7 +454,10 @@ func TestNatsRPCServerProcessBindings(t *testing.T) { func TestNatsRPCServerProcessPushes(t *testing.T) { s := helpers.GetTestNatsServer(t) ctrl := gomock.NewController(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() cfg := config.NewDefaultPitayaConfig().Cluster.RPC.Server.Nats cfg.Connect = fmt.Sprintf("nats://%s", s.Addr()) sv := getServer() @@ -459,7 +492,10 @@ func TestNatsRPCServerProcessPushes(t *testing.T) { func TestNatsRPCServerProcessKick(t *testing.T) { s := helpers.GetTestNatsServer(t) ctrl := gomock.NewController(t) - defer s.Shutdown() + defer func() { + s.Shutdown() + s.WaitForShutdown() + }() cfg := config.NewDefaultPitayaConfig().Cluster.RPC.Server.Nats cfg.Connect = fmt.Sprintf("nats://%s", s.Addr()) sv := getServer()