From 8bcb6b7c575c1207b97502737b295e774a74653d Mon Sep 17 00:00:00 2001 From: fabriziosalmi Date: Wed, 15 Jan 2025 11:12:25 +0100 Subject: [PATCH] Minor improvements --- blacklist.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++ caddywaf.go | 64 -------------------------------------------------- 2 files changed, 66 insertions(+), 64 deletions(-) diff --git a/blacklist.go b/blacklist.go index 4112e80..c1ea48b 100644 --- a/blacklist.go +++ b/blacklist.go @@ -1,10 +1,12 @@ package caddywaf import ( + "fmt" "net" "os" "strings" + "github.com/oschwald/maxminddb-golang" "go.uber.org/zap" ) @@ -128,3 +130,67 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma return nil } + +// isIPBlacklisted checks if the given IP address is in the blacklist. +func (m *Middleware) isIPBlacklisted(remoteAddr string) bool { + ipStr := extractIP(remoteAddr) + if ipStr == "" { + return false + } + + // Check if the IP is directly blacklisted + if m.ipBlacklist[ipStr] { + return true + } + + // Check if the IP falls within any CIDR range in the blacklist + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + + for blacklistEntry := range m.ipBlacklist { + if strings.Contains(blacklistEntry, "/") { + _, ipNet, err := net.ParseCIDR(blacklistEntry) + if err != nil { + continue + } + if ipNet.Contains(ip) { + return true + } + } + } + + return false +} + +func (m *Middleware) isCountryInList(remoteAddr string, countryList []string, geoIP *maxminddb.Reader) (bool, error) { + if m.geoIPHandler == nil { + return false, fmt.Errorf("geoip handler not initialized") + } + return m.geoIPHandler.IsCountryInList(remoteAddr, countryList, geoIP) +} + +func (m *Middleware) isDNSBlacklisted(host string) bool { + normalizedHost := strings.ToLower(strings.TrimSpace(host)) + if normalizedHost == "" { + m.logger.Warn("Empty host provided for DNS blacklist check") + return false + } + + m.mu.RLock() + defer m.mu.RUnlock() + + if _, exists := m.dnsBlacklist[normalizedHost]; exists { + m.logger.Info("Host is blacklisted", + zap.String("host", host), + zap.String("blacklisted_domain", normalizedHost), + ) + return true + } + + m.logger.Debug("Host is not blacklisted", + zap.String("host", host), + ) + return false +} diff --git a/caddywaf.go b/caddywaf.go index c929554..4db5ff4 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -849,30 +849,6 @@ func (m *Middleware) handleMetricsRequest(w http.ResponseWriter, r *http.Request // ==================== Utility Functions ==================== -func (m *Middleware) isDNSBlacklisted(host string) bool { - normalizedHost := strings.ToLower(strings.TrimSpace(host)) - if normalizedHost == "" { - m.logger.Warn("Empty host provided for DNS blacklist check") - return false - } - - m.mu.RLock() - defer m.mu.RUnlock() - - if _, exists := m.dnsBlacklist[normalizedHost]; exists { - m.logger.Info("Host is blacklisted", - zap.String("host", host), - zap.String("blacklisted_domain", normalizedHost), - ) - return true - } - - m.logger.Debug("Host is not blacklisted", - zap.String("host", host), - ) - return false -} - func (m *Middleware) extractValue(target string, r *http.Request, w http.ResponseWriter) (string, error) { return m.requestValueExtractor.ExtractValue(target, r, w) } @@ -886,13 +862,6 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return m.configLoader.UnmarshalCaddyfile(d, m) } -func (m *Middleware) isCountryInList(remoteAddr string, countryList []string, geoIP *maxminddb.Reader) (bool, error) { - if m.geoIPHandler == nil { - return false, fmt.Errorf("geoip handler not initialized") - } - return m.geoIPHandler.IsCountryInList(remoteAddr, countryList, geoIP) -} - func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) { m.logger.Debug("Starting phase evaluation", zap.Int("phase", phase), @@ -1083,36 +1052,3 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i zap.Int("anomaly_threshold", m.AnomalyThreshold), ) } - -// isIPBlacklisted checks if the given IP address is in the blacklist. -func (m *Middleware) isIPBlacklisted(remoteAddr string) bool { - ipStr := extractIP(remoteAddr) - if ipStr == "" { - return false - } - - // Check if the IP is directly blacklisted - if m.ipBlacklist[ipStr] { - return true - } - - // Check if the IP falls within any CIDR range in the blacklist - ip := net.ParseIP(ipStr) - if ip == nil { - return false - } - - for blacklistEntry := range m.ipBlacklist { - if strings.Contains(blacklistEntry, "/") { - _, ipNet, err := net.ParseCIDR(blacklistEntry) - if err != nil { - continue - } - if ipNet.Contains(ip) { - return true - } - } - } - - return false -}