From 8b7f5459ee31a222f7fd052f651d1cff67494114 Mon Sep 17 00:00:00 2001 From: Gabor Retvari Date: Wed, 7 Feb 2024 11:35:10 +0100 Subject: [PATCH] fix: Tight control of CDS server connection handler lifetime --- pkg/config/server/conn.go | 5 ++++- pkg/config/server/server.go | 17 +++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pkg/config/server/conn.go b/pkg/config/server/conn.go index ca460880..74a962f1 100644 --- a/pkg/config/server/conn.go +++ b/pkg/config/server/conn.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "sync" @@ -15,15 +16,17 @@ type Conn struct { *websocket.Conn Filter ConfigFilter patch ClientConfigPatcher + cancel context.CancelFunc readLock, writeLock sync.Mutex // for writemessage } // NewConn wraps a WebSocket connection. -func NewConn(conn *websocket.Conn, filter ConfigFilter, patch ClientConfigPatcher) *Conn { +func NewConn(conn *websocket.Conn, filter ConfigFilter, patch ClientConfigPatcher, cancel context.CancelFunc) *Conn { return &Conn{ Conn: conn, Filter: filter, patch: patch, + cancel: cancel, } } diff --git a/pkg/config/server/server.go b/pkg/config/server/server.go index bc25c732..1f3d15db 100644 --- a/pkg/config/server/server.go +++ b/pkg/config/server/server.go @@ -140,8 +140,12 @@ func (s *Server) RemoveClient(id string) { } } -func (s *Server) handleConn(ctx context.Context, wsConn *websocket.Conn, operationID string, filter ConfigFilter, patch ClientConfigPatcher) { - conn := NewConn(wsConn, filter, patch) +func (s *Server) handleConn(reqCtx context.Context, wsConn *websocket.Conn, operationID string, filter ConfigFilter, patch ClientConfigPatcher) { + // since wsConn is hijacked, reqCtx is unreliable in that it may not be canceled when the + // connection is closed, so we create our own connection context that we can cancel + // explicitly + ctx, cancel := context.WithCancel(reqCtx) + conn := NewConn(wsConn, filter, patch, cancel) s.conns.Upsert(conn) // a dummy reader that drops everything it receives: this must be there for the @@ -210,8 +214,7 @@ func (s *Server) sendConfig(conn *Conn, e *stnrv1.StunnerConfig) { } func (s *Server) sendJSONConfig(conn *Conn, json []byte) { - s.log.V(2).Info("sending configuration to client", "client", conn.Id(), - "config", string(json)) + s.log.V(2).Info("sending configuration to client", "client", conn.Id()) if err := conn.WriteMessage(websocket.TextMessage, json); err != nil { s.log.Error(err, "error sending config update", "client", conn.Id()) @@ -223,6 +226,12 @@ func (s *Server) closeConn(conn *Conn) { s.log.V(1).Info("closing client connection", "client", conn.Id()) conn.WriteMessage(websocket.CloseMessage, []byte{}) //nolint:errcheck + + if conn.cancel != nil { + conn.cancel() + conn.cancel = nil // make sure we can cancel multiple times + } + s.conns.Delete(conn) conn.Close() }