Skip to content

Commit

Permalink
Added comments, removed comments stating obvious things and added som…
Browse files Browse the repository at this point in the history
…e questions
  • Loading branch information
thejoeker12 committed Jul 9, 2024
1 parent 224218c commit 666e22f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 42 deletions.
4 changes: 2 additions & 2 deletions httpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ func (c *ClientConfig) Build() (*Client, error) {
}

// TODO refactor redirects
if err := redirecthandler.SetupRedirectHandler(httpClient, c.FollowRedirects, c.MaxRedirects, c.Sugar); err != nil {
return nil, fmt.Errorf("Failed to set up redirect handler: %v", err)
if c.FollowRedirects {
redirecthandler.SetupRedirectHandler(httpClient, c.MaxRedirects, c.Sugar)
}

// TODO refactor concurrency
Expand Down
78 changes: 38 additions & 40 deletions redirecthandler/redirecthandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,31 @@ func (r *RedirectHandler) WithRedirectHandling(client *http.Client) {

// checkRedirect implements the redirect handling logic.
func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request) error {

// Ensure redirect history is always cleared to prevent memory leaks
defer r.clearRedirectHistory(req)

// Non-idempotent methods handling
// Enforce max redirects
if len(via) >= r.MaxRedirects {
r.Logger.Warn("Maximum redirects reached", zap.Int("maxRedirects", r.MaxRedirects))
return &MaxRedirectsError{MaxRedirects: r.MaxRedirects}
}

// Return if disallowed method
// TODO why?
if req.Method == http.MethodPost || req.Method == http.MethodPatch {
r.Logger.Warn("Redirect attempted on non-idempotent method, not following", zap.String("method", req.Method))
// Stop redirection and return the response as is
return http.ErrUseLastResponse
}

// Check for cached permanent redirect
if urlString, ok := r.checkPermanentRedirect(req.URL.String()); ok && (req.Method == http.MethodGet || req.Method == http.MethodHead) {
// TODO why do we need to cache these?
urlString, ok := r.checkPermanentRedirect(req.URL.String())
if ok && (req.Method == http.MethodGet || req.Method == http.MethodHead) {
parsedURL, err := url.Parse(urlString)
if err != nil {
// TODO is there ever a time where the cached one will be invalid?
r.Logger.Error("Failed to parse URL from cache", zap.String("url", urlString), zap.Error(err))
// Continue with the original URL since the cached URL is invalid
} else {
req.URL = parsedURL // Use cached redirect location
req.URL = parsedURL
r.Logger.Info("Using cached permanent redirect", zap.String("originalURL", urlString), zap.String("redirectURL", parsedURL.String()))
return nil
}
Expand All @@ -73,17 +79,11 @@ func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request)
r.RedirectHistories[req] = append(r.RedirectHistories[req], req.URL)

// Check for redirect loops by analyzing the history
if hasLoop(r.RedirectHistories[req]) {
if redirectLoop(r.RedirectHistories[req]) {
r.Logger.Error("Redirect loop detected", zap.Any("redirectHistory", r.RedirectHistories[req]))
return fmt.Errorf("redirect loop detected: %v", r.RedirectHistories[req])
}

// Enforce max redirects
if len(via) >= r.MaxRedirects {
r.Logger.Warn("Maximum redirects reached", zap.Int("maxRedirects", r.MaxRedirects))
return &MaxRedirectsError{MaxRedirects: r.MaxRedirects}
}

lastResponse := via[len(via)-1].Response
if lastResponse.StatusCode == http.StatusPermanentRedirect || lastResponse.StatusCode == http.StatusTemporaryRedirect {
location, err := lastResponse.Location()
Expand All @@ -98,40 +98,36 @@ func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request)
return err
}

// Apply security measures for cross-domain redirects
if newReqURL.Host != req.URL.Host {
r.secureRequest(req)
}

// Cache permanent redirects
if lastResponse.StatusCode == http.StatusPermanentRedirect {
r.cachePermanentRedirect(req.URL.String(), newReqURL.String())
}

// Special handling for 303 See Other
if lastResponse.StatusCode == http.StatusSeeOther {
r.adjustForSeeOther(req)
}

r.Logger.Info("Redirecting request", zap.String("originalURL", req.URL.String()), zap.String("newURL", newReqURL.String()), zap.Int("redirectCount", len(via)))
req.URL = newReqURL // Update request URL to follow the redirect
req.URL = newReqURL
return nil
}

// Clear redirect history if redirect is successful
if len(via) > 0 && lastResponse.StatusCode >= 200 && lastResponse.StatusCode < 400 {
// Clear history for the redirected request
redirectedReq := via[len(via)-1]
r.clearRedirectHistory(redirectedReq)
}

return http.ErrUseLastResponse // No further action required if not a redirect status code
return http.ErrUseLastResponse
}

// resolveRedirectURL resolves the redirect location URL against the current request URL.
func (r *RedirectHandler) resolveRedirectURL(reqURL *url.URL, redirectURL *url.URL) (*url.URL, error) {
if !redirectURL.IsAbs() {
redirectURL.Scheme = reqURL.Scheme // Preserve the scheme
redirectURL.Scheme = reqURL.Scheme
}
return redirectURL, nil
}
Expand Down Expand Up @@ -189,21 +185,30 @@ func (r *RedirectHandler) checkPermanentRedirect(originalURL string) (string, bo
return url, exists
}

// hasLoop checks if there's a loop in the redirect history.
func hasLoop(history []*url.URL) bool {
urlSet := make(map[string]struct{})
for _, url := range history {
if _, exists := urlSet[url.String()]; exists {
return true // Loop detected
// redirectLoop checks if there's a loop in the redirect history.
func redirectLoop(history []*url.URL) bool {
var urls []string
for _, v := range history {
urls = append(urls, v.String())
}

// if duplicates found at different indexes in loop. I don't think it's pretty but it works.
for i, j := range urls {
for k, l := range urls {
if i != k {
if j == l {
return true
}
}
}
urlSet[url.String()] = struct{}{}
}

return false
}

// clearRedirectHistory clears the redirect history for a given request to prevent memory leaks.
func (r *RedirectHandler) clearRedirectHistory(req *http.Request) {
r.VisitedURLsMutex.Lock() // Use the appropriate mutex to synchronize access to RedirectHistories
r.VisitedURLsMutex.Lock()
delete(r.RedirectHistories, req)
r.VisitedURLsMutex.Unlock()
}
Expand All @@ -217,16 +222,9 @@ func (r *RedirectHandler) GetRedirectHistory(req *http.Request) []*url.URL {
}

// SetupRedirectHandler configures the HTTP client for redirect handling based on the client configuration.
func SetupRedirectHandler(client *http.Client, followRedirects bool, maxRedirects int, log *zap.SugaredLogger) error {
if followRedirects {
if maxRedirects < 1 {
log.Error("Invalid maxRedirects value", zap.Int("maxRedirects", maxRedirects))
return fmt.Errorf("invalid maxRedirects value: %d", maxRedirects)
}
func SetupRedirectHandler(client *http.Client, maxRedirects int, log *zap.SugaredLogger) {
redirectHandler := NewRedirectHandler(log, maxRedirects)
redirectHandler.WithRedirectHandling(client)
log.Info("Redirect handling enabled", zap.Int("MaxRedirects", maxRedirects))

redirectHandler := NewRedirectHandler(log, maxRedirects)
redirectHandler.WithRedirectHandling(client)
log.Info("Redirect handling enabled", zap.Int("MaxRedirects", maxRedirects))
}
return nil
}

0 comments on commit 666e22f

Please sign in to comment.