Skip to content

Commit

Permalink
IP blacklisting fixed. IP & DNS metrics works properly now. Rule head…
Browse files Browse the repository at this point in the history
…er-suspicious-x-forwarded-for temporary removed due to excessive false positive ratio.
  • Loading branch information
fabriziosalmi committed Jan 26, 2025
1 parent 86c96b9 commit ee5fc71
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 179 deletions.
4 changes: 2 additions & 2 deletions Caddyfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
log {
output stdout
format console
level INFO
level DEBUG
}

handle {
Expand Down Expand Up @@ -47,7 +47,7 @@
# rule_file rules/wordpress.json
ip_blacklist_file ip_blacklist.txt
dns_blacklist_file dns_blacklist.txt
log_severity info
log_severity debug
log_json
log_path debug.json
# redact_sensitive_data
Expand Down
165 changes: 0 additions & 165 deletions caddywaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,168 +570,3 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
}
return m.configLoader.UnmarshalCaddyfile(d, m)
}

func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) {
m.logger.Debug("Starting phase evaluation",
zap.Int("phase", phase),
zap.String("source_ip", r.RemoteAddr),
zap.String("user_agent", r.UserAgent()),
)

if phase == 1 && m.CountryBlock.Enabled {
m.logger.Debug("Starting country blocking phase")
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP)
if err != nil {
m.logRequest(zapcore.ErrorLevel, "Failed to check country block",
r,
zap.Error(err),
)
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked due to internal error"),
)
m.logger.Debug("Country blocking phase completed - blocked due to error")
return
} else if blocked {
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked by country"),
)
m.logger.Debug("Country blocking phase completed - blocked by country")
return
}
m.logger.Debug("Country blocking phase completed - not blocked")
}

if phase == 1 && m.rateLimiter != nil {
m.logger.Debug("Starting rate limiting phase")
ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here
path := r.URL.Path // Get the request path
if m.rateLimiter.isRateLimited(ip, path) {
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr,
zap.String("message", "Request blocked by rate limit"),
)
m.logger.Debug("Rate limiting phase completed - blocked by rate limit")
return
}
m.logger.Debug("Rate limiting phase completed - not blocked")
}

if phase == 1 && m.isIPBlacklisted(r.RemoteAddr) {
m.logger.Debug("Starting IP blacklist phase")
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr,
zap.String("message", "Request blocked by IP blacklist"),
)
m.logger.Debug("IP blacklist phase completed - blocked")
return
}

if phase == 1 && m.isDNSBlacklisted(r.Host) {
m.logger.Debug("Starting DNS blacklist phase")
m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host,
zap.String("message", "Request blocked by DNS blacklist"),
zap.String("host", r.Host),
)
m.logger.Debug("DNS blacklist phase completed - blocked")
return
}

rules, ok := m.Rules[phase]
if !ok {
m.logger.Debug("No rules found for phase", zap.Int("phase", phase))
return
}

m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))
for _, rule := range rules {
m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets)))

// Use the custom type as the key
ctx := context.WithValue(r.Context(), ContextKeyRule("rule_id"), rule.ID)
r = r.WithContext(ctx)

for _, target := range rule.Targets {
m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", string(rule.ID)))
var value string
var err error

if phase == 3 || phase == 4 {
if recorder, ok := w.(*responseRecorder); ok {
value, err = m.extractValue(target, r, recorder)
} else {
m.logger.Error("response recorder is not available in phase 3 or 4 when required")
value, err = m.extractValue(target, r, nil)
}
} else {
value, err = m.extractValue(target, r, nil)
}

if err != nil {
m.logger.Debug("Failed to extract value for target, skipping rule for this target",
zap.String("target", target),
zap.String("rule_id", string(rule.ID)),
zap.Error(err),
)
continue
}

m.logger.Debug("Extracted value",
zap.String("rule_id", string(rule.ID)),
zap.String("target", target),
zap.String("value", value),
)

if rule.regex.MatchString(value) {
m.logger.Debug("Rule matched",
zap.String("rule_id", string(rule.ID)),
zap.String("target", target),
zap.String("value", value),
)
if phase == 3 || phase == 4 {
if recorder, ok := w.(*responseRecorder); ok {
if !m.processRuleMatch(recorder, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
} else {
if !m.processRuleMatch(w, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
}
} else {
if !m.processRuleMatch(w, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
}
if state.Blocked || state.ResponseWritten {
m.logger.Debug("Rule evaluation completed early due to blocking or response written", zap.Int("phase", phase), zap.String("rule_id", string(rule.ID)))
return
}
} else {
m.logger.Debug("Rule did not match",
zap.String("rule_id", string(rule.ID)),
zap.String("target", target),
zap.String("value", value),
)
}
}
}
m.logger.Debug("Rule evaluation completed for phase", zap.Int("phase", phase))

if phase == 3 {
m.logger.Debug("Starting response headers phase")
if _, ok := w.(*responseRecorder); ok {
m.logger.Debug("Response headers phase completed")
}
}

if phase == 4 {
m.logger.Debug("Starting response body phase")
if _, ok := w.(*responseRecorder); ok {
m.logger.Debug("Response body phase completed")
}
}

m.logger.Debug("Completed phase evaluation",
zap.Int("phase", phase),
zap.Int("total_score", state.TotalScore),
zap.Int("anomaly_threshold", m.AnomalyThreshold),
)
}
192 changes: 192 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package caddywaf
import (
"context"
"net/http"
"strings"

"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/google/uuid"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)

// ServeHTTP implements caddyhttp.Handler.
Expand Down Expand Up @@ -194,3 +196,193 @@ func (m *Middleware) copyResponse(w http.ResponseWriter, recorder *responseRecor
m.logger.Error("Failed to write recorded response body to client", zap.Error(err), zap.String("log_id", r.Context().Value("logID").(string)))
}
}

func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) {
m.logger.Debug("Starting phase evaluation",
zap.Int("phase", phase),
zap.String("source_ip", r.RemoteAddr),
zap.String("user_agent", r.UserAgent()),
)

if phase == 1 && m.CountryBlock.Enabled {
m.logger.Debug("Starting country blocking phase")
blocked, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP)
if err != nil {
m.logRequest(zapcore.ErrorLevel, "Failed to check country block",
r,
zap.Error(err),
)
m.blockRequest(w, r, state, http.StatusForbidden, "internal_error", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked due to internal error"),
)
m.logger.Debug("Country blocking phase completed - blocked due to error")
return
} else if blocked {
m.blockRequest(w, r, state, http.StatusForbidden, "country_block", "country_block_rule", r.RemoteAddr,
zap.String("message", "Request blocked by country"),
)
m.logger.Debug("Country blocking phase completed - blocked by country")
return
}
m.logger.Debug("Country blocking phase completed - not blocked")
}

if phase == 1 && m.rateLimiter != nil {
m.logger.Debug("Starting rate limiting phase")
ip := extractIP(r.RemoteAddr, m.logger) // Pass the logger here
path := r.URL.Path // Get the request path
if m.rateLimiter.isRateLimited(ip, path) {
m.blockRequest(w, r, state, http.StatusTooManyRequests, "rate_limit", "rate_limit_rule", r.RemoteAddr,
zap.String("message", "Request blocked by rate limit"),
)
m.logger.Debug("Rate limiting phase completed - blocked by rate limit")
return
}
m.logger.Debug("Rate limiting phase completed - not blocked")
}

if phase == 1 {
m.logger.Debug("Checking for IP blacklisting", zap.String("remote_addr", r.RemoteAddr)) //Added log for checking before to isIPBlacklisted call
xForwardedFor := r.Header.Get("X-Forwarded-For")
if xForwardedFor != "" {
ips := strings.Split(xForwardedFor, ",")
if len(ips) > 0 {
firstIP := strings.TrimSpace(ips[0])
m.logger.Debug("Checking IP blacklist with X-Forwarded-For", zap.String("remote_addr_xff", firstIP), zap.String("r.RemoteAddr", r.RemoteAddr))
if m.isIPBlacklisted(firstIP) {
m.logger.Debug("Starting IP blacklist phase")
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", firstIP,
zap.String("message", "Request blocked by IP blacklist"),
)
m.logger.Debug("IP blacklist phase completed - blocked")
return
}
} else {
m.logger.Debug("X-Forwarded-For header present but empty or invalid")

}

} else {
m.logger.Debug("X-Forwarded-For header not present using r.RemoteAddr")
if m.isIPBlacklisted(r.RemoteAddr) {
m.logger.Debug("Starting IP blacklist phase")
m.blockRequest(w, r, state, http.StatusForbidden, "ip_blacklist", "ip_blacklist_rule", r.RemoteAddr,
zap.String("message", "Request blocked by IP blacklist"),
)
m.logger.Debug("IP blacklist phase completed - blocked")
return
}
}
}

if phase == 1 && m.isDNSBlacklisted(r.Host) {
m.logger.Debug("Starting DNS blacklist phase")
m.blockRequest(w, r, state, http.StatusForbidden, "dns_blacklist", "dns_blacklist_rule", r.Host,
zap.String("message", "Request blocked by DNS blacklist"),
zap.String("host", r.Host),
)
m.logger.Debug("DNS blacklist phase completed - blocked")
return
}

rules, ok := m.Rules[phase]
if !ok {
m.logger.Debug("No rules found for phase", zap.Int("phase", phase))
return
}

m.logger.Debug("Starting rule evaluation for phase", zap.Int("phase", phase), zap.Int("rule_count", len(rules)))
for _, rule := range rules {
m.logger.Debug("Processing rule", zap.String("rule_id", string(rule.ID)), zap.Int("target_count", len(rule.Targets)))

// Use the custom type as the key
ctx := context.WithValue(r.Context(), ContextKeyRule("rule_id"), rule.ID)
r = r.WithContext(ctx)

for _, target := range rule.Targets {
m.logger.Debug("Extracting value for target", zap.String("target", target), zap.String("rule_id", string(rule.ID)))
var value string
var err error

if phase == 3 || phase == 4 {
if recorder, ok := w.(*responseRecorder); ok {
value, err = m.extractValue(target, r, recorder)
} else {
m.logger.Error("response recorder is not available in phase 3 or 4 when required")
value, err = m.extractValue(target, r, nil)
}
} else {
value, err = m.extractValue(target, r, nil)
}

if err != nil {
m.logger.Debug("Failed to extract value for target, skipping rule for this target",
zap.String("target", target),
zap.String("rule_id", string(rule.ID)),
zap.Error(err),
)
continue
}

m.logger.Debug("Extracted value",
zap.String("rule_id", string(rule.ID)),
zap.String("target", target),
zap.String("value", value),
)

if rule.regex.MatchString(value) {
m.logger.Debug("Rule matched",
zap.String("rule_id", string(rule.ID)),
zap.String("target", target),
zap.String("value", value),
)
if phase == 3 || phase == 4 {
if recorder, ok := w.(*responseRecorder); ok {
if !m.processRuleMatch(recorder, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
} else {
if !m.processRuleMatch(w, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
}
} else {
if !m.processRuleMatch(w, r, &rule, value, state) {
return // Stop processing if the rule match indicates blocking
}
}
if state.Blocked || state.ResponseWritten {
m.logger.Debug("Rule evaluation completed early due to blocking or response written", zap.Int("phase", phase), zap.String("rule_id", string(rule.ID)))
return
}
} else {
m.logger.Debug("Rule did not match",
zap.String("rule_id", string(rule.ID)),
zap.String("target", target),
zap.String("value", value),
)
}
}
}
m.logger.Debug("Rule evaluation completed for phase", zap.Int("phase", phase))

if phase == 3 {
m.logger.Debug("Starting response headers phase")
if _, ok := w.(*responseRecorder); ok {
m.logger.Debug("Response headers phase completed")
}
}

if phase == 4 {
m.logger.Debug("Starting response body phase")
if _, ok := w.(*responseRecorder); ok {
m.logger.Debug("Response body phase completed")
}
}

m.logger.Debug("Completed phase evaluation",
zap.Int("phase", phase),
zap.Int("total_score", state.TotalScore),
zap.Int("anomaly_threshold", m.AnomalyThreshold),
)
}
Loading

0 comments on commit ee5fc71

Please sign in to comment.