diff --git a/Caddyfile b/Caddyfile index 61f15a5..0cfc7c5 100644 --- a/Caddyfile +++ b/Caddyfile @@ -8,7 +8,7 @@ log { output stdout format console - level INFO + level DEBUG } handle { @@ -47,7 +47,7 @@ # rule_file rules/wordpress.json ip_blacklist_file ip_blacklist.txt dns_blacklist_file dns_blacklist.txt - log_severity info + log_severity debug log_json log_path debug.json # redact_sensitive_data diff --git a/caddywaf.go b/caddywaf.go index 5574c78..2dc99e8 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -570,168 +570,3 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { } return m.configLoader.UnmarshalCaddyfile(d, m) } - -func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) { - m.logger.Debug("Starting phase evaluation", - zap.Int("phase", phase), - zap.String("source_ip", r.RemoteAddr), - zap.String("user_agent", r.UserAgent()), - ) - - if phase == 1 && m.CountryBlock.Enabled { - m.logger.Debug("Starting country blocking phase") - blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP) - if err != nil { - m.logRequest(zapcore.ErrorLevel, "Failed to check country block", - r, - zap.Error(err), - ) - m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, - zap.String("message", "Request blocked due to internal error"), - ) - m.logger.Debug("Country blocking phase completed - blocked due to error") - return - } else if blocked { - m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, - zap.String("message", "Request blocked by country"), - ) - m.logger.Debug("Country blocking phase completed - blocked by country") - return - } - m.logger.Debug("Country blocking phase completed - not blocked") - } - - if phase == 1 && m.rateLimiter != nil { - m.logger.Debug("Starting rate limiting phase") - ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here - path := r.URL.Path // Get the request path - if m.rateLimiter.isRateLimited(ip, path) { - m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr, - zap.String("message", "Request blocked by rate limit"), - ) - m.logger.Debug("Rate limiting phase completed - blocked by rate limit") - return - } - m.logger.Debug("Rate limiting phase completed - not blocked") - } - - if phase == 1 && m.isIPBlacklisted(r.RemoteAddr) { - m.logger.Debug("Starting IP blacklist phase") - m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr, - zap.String("message", "Request blocked by IP blacklist"), - ) - m.logger.Debug("IP blacklist phase completed - blocked") - return - } - - if phase == 1 && m.isDNSBlacklisted(r.Host) { - m.logger.Debug("Starting DNS blacklist phase") - m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host, - zap.String("message", "Request blocked by DNS blacklist"), - zap.String("host", r.Host), - ) - m.logger.Debug("DNS blacklist phase completed - blocked") - return - } - - rules, ok := m.Rules[phase] - if !ok { - m.logger.Debug("No rules found for phase", zap.Int("phase", phase)) - return - } - - m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules))) - for _, rule := range rules { - m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets))) - - // Use the custom type as the key - ctx := context.WithValue(r.Context(), ContextKeyRule("rule_id"), rule.ID) - r = r.WithContext(ctx) - - for _, target := range rule.Targets { - m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", string(rule.ID))) - var value string - var err error - - if phase == 3 || phase == 4 { - if recorder, ok := w.(*responseRecorder); ok { - value, err = m.extractValue(target, r, recorder) - } else { - m.logger.Error("response recorder is not available in phase 3 or 4 when required") - value, err = m.extractValue(target, r, nil) - } - } else { - value, err = m.extractValue(target, r, nil) - } - - if err != nil { - m.logger.Debug("Failed to extract value for target, skipping rule for this target", - zap.String("target", target), - zap.String("rule_id", string(rule.ID)), - zap.Error(err), - ) - continue - } - - m.logger.Debug("Extracted value", - zap.String("rule_id", string(rule.ID)), - zap.String("target", target), - zap.String("value", value), - ) - - if rule.regex.MatchString(value) { - m.logger.Debug("Rule matched", - zap.String("rule_id", string(rule.ID)), - zap.String("target", target), - zap.String("value", value), - ) - if phase == 3 || phase == 4 { - if recorder, ok := w.(*responseRecorder); ok { - if !m.processRuleMatch(recorder, r, &rule, value, state) { - return // Stop processing if the rule match indicates blocking - } - } else { - if !m.processRuleMatch(w, r, &rule, value, state) { - return // Stop processing if the rule match indicates blocking - } - } - } else { - if !m.processRuleMatch(w, r, &rule, value, state) { - return // Stop processing if the rule match indicates blocking - } - } - if state.Blocked || state.ResponseWritten { - m.logger.Debug("Rule evaluation completed early due to blocking or response written", zap.Int("phase", phase), zap.String("rule_id", string(rule.ID))) - return - } - } else { - m.logger.Debug("Rule did not match", - zap.String("rule_id", string(rule.ID)), - zap.String("target", target), - zap.String("value", value), - ) - } - } - } - m.logger.Debug("Rule evaluation completed for phase", zap.Int("phase", phase)) - - if phase == 3 { - m.logger.Debug("Starting response headers phase") - if _, ok := w.(*responseRecorder); ok { - m.logger.Debug("Response headers phase completed") - } - } - - if phase == 4 { - m.logger.Debug("Starting response body phase") - if _, ok := w.(*responseRecorder); ok { - m.logger.Debug("Response body phase completed") - } - } - - m.logger.Debug("Completed phase evaluation", - zap.Int("phase", phase), - zap.Int("total_score", state.TotalScore), - zap.Int("anomaly_threshold", m.AnomalyThreshold), - ) -} diff --git a/handler.go b/handler.go index b6d2b2c..4ee9a35 100644 --- a/handler.go +++ b/handler.go @@ -3,10 +3,12 @@ package caddywaf import ( "context" "net/http" + "strings" "github.com/caddyserver/caddy/v2/modules/caddyhttp" "github.com/google/uuid" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) // ServeHTTP implements caddyhttp.Handler. @@ -194,3 +196,193 @@ func (m *Middleware) copyResponse(w http.ResponseWriter, recorder *responseRecor m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", r.Context().Value("logID").(string))) } } + +func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) { + m.logger.Debug("Starting phase evaluation", + zap.Int("phase", phase), + zap.String("source_ip", r.RemoteAddr), + zap.String("user_agent", r.UserAgent()), + ) + + if phase == 1 && m.CountryBlock.Enabled { + m.logger.Debug("Starting country blocking phase") + blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP) + if err != nil { + m.logRequest(zapcore.ErrorLevel, "Failed to check country block", + r, + zap.Error(err), + ) + m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr, + zap.String("message", "Request blocked due to internal error"), + ) + m.logger.Debug("Country blocking phase completed - blocked due to error") + return + } else if blocked { + m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr, + zap.String("message", "Request blocked by country"), + ) + m.logger.Debug("Country blocking phase completed - blocked by country") + return + } + m.logger.Debug("Country blocking phase completed - not blocked") + } + + if phase == 1 && m.rateLimiter != nil { + m.logger.Debug("Starting rate limiting phase") + ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here + path := r.URL.Path // Get the request path + if m.rateLimiter.isRateLimited(ip, path) { + m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr, + zap.String("message", "Request blocked by rate limit"), + ) + m.logger.Debug("Rate limiting phase completed - blocked by rate limit") + return + } + m.logger.Debug("Rate limiting phase completed - not blocked") + } + + if phase == 1 { + m.logger.Debug("Checking for IP blacklisting", zap.String("remote_addr", r.RemoteAddr)) //Added log for checking before to isIPBlacklisted call + xForwardedFor := r.Header.Get("X-Forwarded-For") + if xForwardedFor != "" { + ips := strings.Split(xForwardedFor, ",") + if len(ips) > 0 { + firstIP := strings.TrimSpace(ips[0]) + m.logger.Debug("Checking IP blacklist with X-Forwarded-For", zap.String("remote_addr_xff", firstIP), zap.String("r.RemoteAddr", r.RemoteAddr)) + if m.isIPBlacklisted(firstIP) { + m.logger.Debug("Starting IP blacklist phase") + m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", firstIP, + zap.String("message", "Request blocked by IP blacklist"), + ) + m.logger.Debug("IP blacklist phase completed - blocked") + return + } + } else { + m.logger.Debug("X-Forwarded-For header present but empty or invalid") + + } + + } else { + m.logger.Debug("X-Forwarded-For header not present using r.RemoteAddr") + if m.isIPBlacklisted(r.RemoteAddr) { + m.logger.Debug("Starting IP blacklist phase") + m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr, + zap.String("message", "Request blocked by IP blacklist"), + ) + m.logger.Debug("IP blacklist phase completed - blocked") + return + } + } + } + + if phase == 1 && m.isDNSBlacklisted(r.Host) { + m.logger.Debug("Starting DNS blacklist phase") + m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host, + zap.String("message", "Request blocked by DNS blacklist"), + zap.String("host", r.Host), + ) + m.logger.Debug("DNS blacklist phase completed - blocked") + return + } + + rules, ok := m.Rules[phase] + if !ok { + m.logger.Debug("No rules found for phase", zap.Int("phase", phase)) + return + } + + m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules))) + for _, rule := range rules { + m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets))) + + // Use the custom type as the key + ctx := context.WithValue(r.Context(), ContextKeyRule("rule_id"), rule.ID) + r = r.WithContext(ctx) + + for _, target := range rule.Targets { + m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", string(rule.ID))) + var value string + var err error + + if phase == 3 || phase == 4 { + if recorder, ok := w.(*responseRecorder); ok { + value, err = m.extractValue(target, r, recorder) + } else { + m.logger.Error("response recorder is not available in phase 3 or 4 when required") + value, err = m.extractValue(target, r, nil) + } + } else { + value, err = m.extractValue(target, r, nil) + } + + if err != nil { + m.logger.Debug("Failed to extract value for target, skipping rule for this target", + zap.String("target", target), + zap.String("rule_id", string(rule.ID)), + zap.Error(err), + ) + continue + } + + m.logger.Debug("Extracted value", + zap.String("rule_id", string(rule.ID)), + zap.String("target", target), + zap.String("value", value), + ) + + if rule.regex.MatchString(value) { + m.logger.Debug("Rule matched", + zap.String("rule_id", string(rule.ID)), + zap.String("target", target), + zap.String("value", value), + ) + if phase == 3 || phase == 4 { + if recorder, ok := w.(*responseRecorder); ok { + if !m.processRuleMatch(recorder, r, &rule, value, state) { + return // Stop processing if the rule match indicates blocking + } + } else { + if !m.processRuleMatch(w, r, &rule, value, state) { + return // Stop processing if the rule match indicates blocking + } + } + } else { + if !m.processRuleMatch(w, r, &rule, value, state) { + return // Stop processing if the rule match indicates blocking + } + } + if state.Blocked || state.ResponseWritten { + m.logger.Debug("Rule evaluation completed early due to blocking or response written", zap.Int("phase", phase), zap.String("rule_id", string(rule.ID))) + return + } + } else { + m.logger.Debug("Rule did not match", + zap.String("rule_id", string(rule.ID)), + zap.String("target", target), + zap.String("value", value), + ) + } + } + } + m.logger.Debug("Rule evaluation completed for phase", zap.Int("phase", phase)) + + if phase == 3 { + m.logger.Debug("Starting response headers phase") + if _, ok := w.(*responseRecorder); ok { + m.logger.Debug("Response headers phase completed") + } + } + + if phase == 4 { + m.logger.Debug("Starting response body phase") + if _, ok := w.(*responseRecorder); ok { + m.logger.Debug("Response body phase completed") + } + } + + m.logger.Debug("Completed phase evaluation", + zap.Int("phase", phase), + zap.Int("total_score", state.TotalScore), + zap.Int("anomaly_threshold", m.AnomalyThreshold), + ) +} diff --git a/rules.json b/rules.json index 18b3188..209a3b9 100644 --- a/rules.json +++ b/rules.json @@ -71,18 +71,6 @@ "score": 9, "description": "Block SQL injection, XSS, and path traversal attempts in headers. Improved pattern matching." }, - { - "id": "header-suspicious-x-forwarded-for", - "phase": 1, - "pattern": "(?:127\\.0\\.0\\.1|10\\.|172\\.(?:1[6-9]|2\\d|3[01])\\.|192\\.168\\.|169\\.254\\.|::1)", - "targets": [ - "HEADERS:X-Forwarded-For" - ], - "severity": "MEDIUM", - "action": "block", - "score": 6, - "description": "Block requests with potentially internal IPs in X-Forwarded-For. Added more internal IP ranges." - }, { "id": "http-request-smuggling", "phase": 1,