diff --git a/httpclient/client.go b/httpclient/client.go index 0a7593b..06789bf 100644 --- a/httpclient/client.go +++ b/httpclient/client.go @@ -110,8 +110,8 @@ func (c *ClientConfig) Build() (*Client, error) { } // TODO refactor redirects - if err := redirecthandler.SetupRedirectHandler(httpClient, c.FollowRedirects, c.MaxRedirects, c.Sugar); err != nil { - return nil, fmt.Errorf("Failed to set up redirect handler: %v", err) + if c.FollowRedirects { + redirecthandler.SetupRedirectHandler(httpClient, c.MaxRedirects, c.Sugar) } // TODO refactor concurrency diff --git a/redirecthandler/redirecthandler.go b/redirecthandler/redirecthandler.go index 61e4310..80d32a4 100644 --- a/redirecthandler/redirecthandler.go +++ b/redirecthandler/redirecthandler.go @@ -45,25 +45,31 @@ func (r *RedirectHandler) WithRedirectHandling(client *http.Client) { // checkRedirect implements the redirect handling logic. func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request) error { - - // Ensure redirect history is always cleared to prevent memory leaks defer r.clearRedirectHistory(req) - // Non-idempotent methods handling + // Enforce max redirects + if len(via) >= r.MaxRedirects { + r.Logger.Warn("Maximum redirects reached", zap.Int("maxRedirects", r.MaxRedirects)) + return &MaxRedirectsError{MaxRedirects: r.MaxRedirects} + } + + // Return if disallowed method + // TODO why? if req.Method == http.MethodPost || req.Method == http.MethodPatch { r.Logger.Warn("Redirect attempted on non-idempotent method, not following", zap.String("method", req.Method)) - // Stop redirection and return the response as is return http.ErrUseLastResponse } // Check for cached permanent redirect - if urlString, ok := r.checkPermanentRedirect(req.URL.String()); ok && (req.Method == http.MethodGet || req.Method == http.MethodHead) { + // TODO why do we need to cache these? + urlString, ok := r.checkPermanentRedirect(req.URL.String()) + if ok && (req.Method == http.MethodGet || req.Method == http.MethodHead) { parsedURL, err := url.Parse(urlString) if err != nil { + // TODO is there ever a time where the cached one will be invalid? r.Logger.Error("Failed to parse URL from cache", zap.String("url", urlString), zap.Error(err)) - // Continue with the original URL since the cached URL is invalid } else { - req.URL = parsedURL // Use cached redirect location + req.URL = parsedURL r.Logger.Info("Using cached permanent redirect", zap.String("originalURL", urlString), zap.String("redirectURL", parsedURL.String())) return nil } @@ -73,17 +79,11 @@ func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request) r.RedirectHistories[req] = append(r.RedirectHistories[req], req.URL) // Check for redirect loops by analyzing the history - if hasLoop(r.RedirectHistories[req]) { + if redirectLoop(r.RedirectHistories[req]) { r.Logger.Error("Redirect loop detected", zap.Any("redirectHistory", r.RedirectHistories[req])) return fmt.Errorf("redirect loop detected: %v", r.RedirectHistories[req]) } - // Enforce max redirects - if len(via) >= r.MaxRedirects { - r.Logger.Warn("Maximum redirects reached", zap.Int("maxRedirects", r.MaxRedirects)) - return &MaxRedirectsError{MaxRedirects: r.MaxRedirects} - } - lastResponse := via[len(via)-1].Response if lastResponse.StatusCode == http.StatusPermanentRedirect || lastResponse.StatusCode == http.StatusTemporaryRedirect { location, err := lastResponse.Location() @@ -98,40 +98,36 @@ func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request) return err } - // Apply security measures for cross-domain redirects if newReqURL.Host != req.URL.Host { r.secureRequest(req) } - // Cache permanent redirects if lastResponse.StatusCode == http.StatusPermanentRedirect { r.cachePermanentRedirect(req.URL.String(), newReqURL.String()) } - // Special handling for 303 See Other if lastResponse.StatusCode == http.StatusSeeOther { r.adjustForSeeOther(req) } r.Logger.Info("Redirecting request", zap.String("originalURL", req.URL.String()), zap.String("newURL", newReqURL.String()), zap.Int("redirectCount", len(via))) - req.URL = newReqURL // Update request URL to follow the redirect + req.URL = newReqURL return nil } // Clear redirect history if redirect is successful if len(via) > 0 && lastResponse.StatusCode >= 200 && lastResponse.StatusCode < 400 { - // Clear history for the redirected request redirectedReq := via[len(via)-1] r.clearRedirectHistory(redirectedReq) } - return http.ErrUseLastResponse // No further action required if not a redirect status code + return http.ErrUseLastResponse } // resolveRedirectURL resolves the redirect location URL against the current request URL. func (r *RedirectHandler) resolveRedirectURL(reqURL *url.URL, redirectURL *url.URL) (*url.URL, error) { if !redirectURL.IsAbs() { - redirectURL.Scheme = reqURL.Scheme // Preserve the scheme + redirectURL.Scheme = reqURL.Scheme } return redirectURL, nil } @@ -189,21 +185,30 @@ func (r *RedirectHandler) checkPermanentRedirect(originalURL string) (string, bo return url, exists } -// hasLoop checks if there's a loop in the redirect history. -func hasLoop(history []*url.URL) bool { - urlSet := make(map[string]struct{}) - for _, url := range history { - if _, exists := urlSet[url.String()]; exists { - return true // Loop detected +// redirectLoop checks if there's a loop in the redirect history. +func redirectLoop(history []*url.URL) bool { + var urls []string + for _, v := range history { + urls = append(urls, v.String()) + } + + // if duplicates found at different indexes in loop. I don't think it's pretty but it works. + for i, j := range urls { + for k, l := range urls { + if i != k { + if j == l { + return true + } + } } - urlSet[url.String()] = struct{}{} } + return false } // clearRedirectHistory clears the redirect history for a given request to prevent memory leaks. func (r *RedirectHandler) clearRedirectHistory(req *http.Request) { - r.VisitedURLsMutex.Lock() // Use the appropriate mutex to synchronize access to RedirectHistories + r.VisitedURLsMutex.Lock() delete(r.RedirectHistories, req) r.VisitedURLsMutex.Unlock() } @@ -217,16 +222,9 @@ func (r *RedirectHandler) GetRedirectHistory(req *http.Request) []*url.URL { } // SetupRedirectHandler configures the HTTP client for redirect handling based on the client configuration. -func SetupRedirectHandler(client *http.Client, followRedirects bool, maxRedirects int, log *zap.SugaredLogger) error { - if followRedirects { - if maxRedirects < 1 { - log.Error("Invalid maxRedirects value", zap.Int("maxRedirects", maxRedirects)) - return fmt.Errorf("invalid maxRedirects value: %d", maxRedirects) - } +func SetupRedirectHandler(client *http.Client, maxRedirects int, log *zap.SugaredLogger) { + redirectHandler := NewRedirectHandler(log, maxRedirects) + redirectHandler.WithRedirectHandling(client) + log.Info("Redirect handling enabled", zap.Int("MaxRedirects", maxRedirects)) - redirectHandler := NewRedirectHandler(log, maxRedirects) - redirectHandler.WithRedirectHandling(client) - log.Info("Redirect handling enabled", zap.Int("MaxRedirects", maxRedirects)) - } - return nil }