From d30e4aec4d33849b977005a03dc02a959126b142 Mon Sep 17 00:00:00 2001 From: Philip Potter Date: Thu, 10 Oct 2024 16:53:30 +0100 Subject: [PATCH] better retry logic This is a substantial refactoring: - I've changed the retry logic from making a recursive call to using a loop - I've started using the existing r.options.retryPolicy to define retry behavior - I've extracted a common streamEventsTo function, used by both streamTextTo and streamFilesTo, so that retry behavior is defined in one place --- stream.go | 171 +++++++++++++++++++++++++----------------------------- 1 file changed, 78 insertions(+), 93 deletions(-) diff --git a/stream.go b/stream.go index e64deaa..9b58b91 100644 --- a/stream.go +++ b/stream.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "strings" + "time" "unicode/utf8" "github.com/launchdarkly/eventsource" @@ -169,55 +170,17 @@ func (r *Client) StreamPredictionText(ctx context.Context, prediction *Predictio } func (r *Client) streamTextTo(ctx context.Context, writer *io.PipeWriter, url string, lastEventID string) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - writer.CloseWithError(fmt.Errorf("failed to create request: %w", err)) - return - } - req.Header.Set("Accept", "text/event-stream") - - if lastEventID != "" { - req.Header.Set("Last-Event-ID", lastEventID) - } + defer writer.Close() + ctx, cancel := context.WithCancel(ctx) + defer cancel() - resp, err := r.c.Do(req) - if err != nil { - writer.CloseWithError(fmt.Errorf("failed to send request: %w", err)) - return - } + ch := make(chan event) + go r.streamEventsTo(ctx, ch, url, lastEventID) - if resp.StatusCode != http.StatusOK { - writer.CloseWithError(fmt.Errorf("received invalid status code: %d", resp.StatusCode)) - return - } - defer resp.Body.Close() - decoder := eventsource.NewDecoder(resp.Body) - for { - event, err := decoder.Decode() + for e := range ch { + _, err := io.WriteString(writer, e.rawData) if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - // retry (TODO: backoff policy?) - r.streamTextTo(ctx, writer, url, lastEventID) - return - } - writer.CloseWithError(fmt.Errorf("Failed to get token: %w", err)) - return - } - lastEventID = event.Id() - switch event.Event() { - case SSETypeOutput: - _, err := io.WriteString(writer, event.Data()) - if err != nil { - writer.CloseWithError(err) - return - } - case SSETypeDone: - writer.Close() - return - case SSETypeLogs: - // TODO - default: - writer.CloseWithError(fmt.Errorf("unknown event type %s", event.Event())) + writer.CloseWithError(err) return } } @@ -276,64 +239,86 @@ func (e *errWrapper) Close() error { return nil } -func (r *Client) streamFilesTo(ctx context.Context, out chan<- streaming.File, url string, lastEventID string) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - out <- fileError(err) - close(out) - return - } - req.Header.Set("Accept", "text/event-stream") +const maxRetries = 3 - if lastEventID != "" { - req.Header.Set("Last-Event-ID", lastEventID) +func (r *Client) streamFilesTo(ctx context.Context, out chan<- streaming.File, url string, lastEventID string) { + defer close(out) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + ch := make(chan event) + go r.streamEventsTo(ctx, ch, url, lastEventID) + + for e := range ch { + if strings.HasPrefix(e.rawData, "data:") { + out <- &dataURL{url: e.rawData} + } else if strings.HasPrefix(e.rawData, "http") { + out <- &httpURL{c: r.c, url: e.rawData} + } else { + out <- fileError(fmt.Errorf("Could not parse URL: %s", e.rawData)) + return + } } +} - resp, err := r.c.Do(req) - if err != nil { - out <- fileError(fmt.Errorf("failed to send request: %w", err)) - close(out) - return - } +type event struct { + rawData string + err error +} - if resp.StatusCode != http.StatusOK { - out <- fileError(fmt.Errorf("received invalid status code: %d", resp.StatusCode)) - close(out) - return - } - defer resp.Body.Close() - decoder := eventsource.NewDecoder(resp.Body) - for { - event, err := decoder.Decode() +func (r *Client) streamEventsTo(ctx context.Context, out chan<- event, url string, lastEventID string) { + defer close(out) +ATTEMPT: + for attempt := 0; attempt <= r.options.retryPolicy.maxRetries; attempt++ { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - // retry (TODO: backoff policy?) - r.streamFilesTo(ctx, out, url, lastEventID) - return - } - out <- fileError(fmt.Errorf("failed to get token: %w", err)) - close(out) + out <- event{err: err} return } - lastEventID = event.Id() - switch event.Event() { - case SSETypeOutput: - if strings.HasPrefix(event.Data(), "data:") { - out <- &dataURL{url: event.Data()} - } else if strings.HasPrefix(event.Data(), "http") { - out <- &httpURL{c: r.c, url: event.Data()} - } - case SSETypeDone: - close(out) + req.Header.Set("Accept", "text/event-stream") + + if lastEventID != "" { + req.Header.Set("Last-Event-ID", lastEventID) + } + + resp, err := r.c.Do(req) + if err != nil { + out <- event{err: fmt.Errorf("failed to send request: %w", err)} return - case SSETypeLogs: - // TODO - default: - out <- fileError(fmt.Errorf("unknown event type %s", event.Event())) + } + + if resp.StatusCode != http.StatusOK { + out <- event{err: fmt.Errorf("received invalid status code: %d", resp.StatusCode)} return } + defer resp.Body.Close() + decoder := eventsource.NewDecoder(resp.Body) + for { + e, err := decoder.Decode() + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + //retry + delay := r.options.retryPolicy.backoff.NextDelay(attempt) + time.Sleep(delay) + continue ATTEMPT + } + out <- event{err: fmt.Errorf("failed to get token: %w", err)} + return + } + lastEventID = e.Id() + switch e.Event() { + case SSETypeOutput: + out <- event{rawData: e.Data()} + case SSETypeDone: + return + case SSETypeLogs: + // TODO + default: + out <- event{err: fmt.Errorf("unknown event type %s", e.Event())} + return + } + } } - } func (r *Client) streamPrediction(ctx context.Context, prediction *Prediction, lastEvent *SSEEvent, sseChan chan SSEEvent, errChan chan error) {