From dcf71c4cc85be7905ed1a31e3f75d1d6687109e1 Mon Sep 17 00:00:00 2001 From: francisco souza <108725+fsouza@users.noreply.github.com> Date: Thu, 18 Aug 2022 23:32:23 -0400 Subject: [PATCH] event: fix connection leak Make sure we close the connection when disabling event monitoring. --- event.go | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/event.go b/event.go index 024b4ecc..ce1fb502 100644 --- a/event.go +++ b/event.go @@ -93,6 +93,7 @@ type eventMonitoringState struct { C chan *APIEvents errC chan error listeners []chan<- *APIEvents + closeConn func() } const ( @@ -229,6 +230,11 @@ func (eventState *eventMonitoringState) disableEventMonitoring() { eventState.enabled = false close(eventState.C) close(eventState.errC) + + if eventState.closeConn != nil { + eventState.closeConn() + eventState.closeConn = nil + } } } @@ -290,7 +296,7 @@ func (eventState *eventMonitoringState) connectWithRetry(c *Client, opts EventsO eventChan := eventState.C errChan := eventState.errC eventState.RUnlock() - err := c.eventHijack(opts, atomic.LoadInt64(&eventState.lastSeen), eventChan, errChan) + closeConn, err := c.eventHijack(opts, atomic.LoadInt64(&eventState.lastSeen), eventChan, errChan) for ; err != nil && retries < maxMonitorConnRetries; retries++ { waitTime := int64(retryInitialWaitTime * math.Pow(2, float64(retries))) time.Sleep(time.Duration(waitTime) * time.Millisecond) @@ -298,8 +304,11 @@ func (eventState *eventMonitoringState) connectWithRetry(c *Client, opts EventsO eventChan = eventState.C errChan = eventState.errC eventState.RUnlock() - err = c.eventHijack(opts, atomic.LoadInt64(&eventState.lastSeen), eventChan, errChan) + closeConn, err = c.eventHijack(opts, atomic.LoadInt64(&eventState.lastSeen), eventChan, errChan) } + eventState.Lock() + defer eventState.Unlock() + eventState.closeConn = closeConn return err } @@ -343,7 +352,7 @@ func (eventState *eventMonitoringState) updateLastSeen(e *APIEvents) { } } -func (c *Client) eventHijack(opts EventsOptions, startTime int64, eventChan chan *APIEvents, errChan chan error) error { +func (c *Client) eventHijack(opts EventsOptions, startTime int64, eventChan chan *APIEvents, errChan chan error) (closeConn func(), err error) { // on reconnect override initial Since with last event seen time if startTime != 0 { opts.Since = strconv.FormatInt(startTime, 10) @@ -356,37 +365,38 @@ func (c *Client) eventHijack(opts EventsOptions, startTime int64, eventChan chan address = c.endpointURL.Host } var dial net.Conn - var err error if c.TLSConfig == nil { dial, err = c.Dialer.Dial(protocol, address) } else { netDialer, ok := c.Dialer.(*net.Dialer) if !ok { - return ErrTLSNotSupported + return nil, ErrTLSNotSupported } dial, err = tlsDialWithDialer(netDialer, protocol, address, c.TLSConfig) } if err != nil { - return err + return nil, err } //lint:ignore SA1019 the alternative doesn't quite work, so keep using the deprecated thing. conn := httputil.NewClientConn(dial, nil) req, err := http.NewRequest(http.MethodGet, uri, nil) if err != nil { - return err + return nil, err } res, err := conn.Do(req) if err != nil { - return err + return nil, err } + + keepRunning := int32(1) //lint:ignore SA1019 the alternative doesn't quite work, so keep using the deprecated thing. go func(res *http.Response, conn *httputil.ClientConn) { defer conn.Close() defer res.Body.Close() decoder := json.NewDecoder(res.Body) - for { + for atomic.LoadInt32(&keepRunning) == 1 { var event APIEvents - if err = decoder.Decode(&event); err != nil { + if err := decoder.Decode(&event); err != nil { if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { c.eventMonitor.RLock() if c.eventMonitor.enabled && c.eventMonitor.C == eventChan { @@ -409,7 +419,9 @@ func (c *Client) eventHijack(opts EventsOptions, startTime int64, eventChan chan c.eventMonitor.RUnlock() } }(res, conn) - return nil + return func() { + atomic.StoreInt32(&keepRunning, 0) + }, nil } // transformEvent takes an event and determines what version it is from