Skip to content

Commit

Permalink
Minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
fabriziosalmi committed Jan 15, 2025
1 parent 62bedf6 commit 8bcb6b7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 64 deletions.
66 changes: 66 additions & 0 deletions blacklist.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package caddywaf

import (
"fmt"
"net"
"os"
"strings"

"github.com/oschwald/maxminddb-golang"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -128,3 +130,67 @@ func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist ma

return nil
}

// isIPBlacklisted checks if the given IP address is in the blacklist.
func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {
ipStr := extractIP(remoteAddr)
if ipStr == "" {
return false
}

// Check if the IP is directly blacklisted
if m.ipBlacklist[ipStr] {
return true
}

// Check if the IP falls within any CIDR range in the blacklist
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}

for blacklistEntry := range m.ipBlacklist {
if strings.Contains(blacklistEntry, "/") {
_, ipNet, err := net.ParseCIDR(blacklistEntry)
if err != nil {
continue
}
if ipNet.Contains(ip) {
return true
}
}
}

return false
}

func (m *Middleware) isCountryInList(remoteAddr string, countryList []string, geoIP *maxminddb.Reader) (bool, error) {
if m.geoIPHandler == nil {
return false, fmt.Errorf("geoip handler not initialized")
}
return m.geoIPHandler.IsCountryInList(remoteAddr, countryList, geoIP)
}

func (m *Middleware) isDNSBlacklisted(host string) bool {
normalizedHost := strings.ToLower(strings.TrimSpace(host))
if normalizedHost == "" {
m.logger.Warn("Empty host provided for DNS blacklist check")
return false
}

m.mu.RLock()
defer m.mu.RUnlock()

if _, exists := m.dnsBlacklist[normalizedHost]; exists {
m.logger.Info("Host is blacklisted",
zap.String("host", host),
zap.String("blacklisted_domain", normalizedHost),
)
return true
}

m.logger.Debug("Host is not blacklisted",
zap.String("host", host),
)
return false
}
64 changes: 0 additions & 64 deletions caddywaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -849,30 +849,6 @@ func (m *Middleware) handleMetricsRequest(w http.ResponseWriter, r *http.Request

// ==================== Utility Functions ====================

func (m *Middleware) isDNSBlacklisted(host string) bool {
normalizedHost := strings.ToLower(strings.TrimSpace(host))
if normalizedHost == "" {
m.logger.Warn("Empty host provided for DNS blacklist check")
return false
}

m.mu.RLock()
defer m.mu.RUnlock()

if _, exists := m.dnsBlacklist[normalizedHost]; exists {
m.logger.Info("Host is blacklisted",
zap.String("host", host),
zap.String("blacklisted_domain", normalizedHost),
)
return true
}

m.logger.Debug("Host is not blacklisted",
zap.String("host", host),
)
return false
}

func (m *Middleware) extractValue(target string, r *http.Request, w http.ResponseWriter) (string, error) {
return m.requestValueExtractor.ExtractValue(target, r, w)
}
Expand All @@ -886,13 +862,6 @@ func (m *Middleware) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return m.configLoader.UnmarshalCaddyfile(d, m)
}

func (m *Middleware) isCountryInList(remoteAddr string, countryList []string, geoIP *maxminddb.Reader) (bool, error) {
if m.geoIPHandler == nil {
return false, fmt.Errorf("geoip handler not initialized")
}
return m.geoIPHandler.IsCountryInList(remoteAddr, countryList, geoIP)
}

func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase int, state *WAFState) {
m.logger.Debug("Starting phase evaluation",
zap.Int("phase", phase),
Expand Down Expand Up @@ -1083,36 +1052,3 @@ func (m *Middleware) handlePhase(w http.ResponseWriter, r *http.Request, phase i
zap.Int("anomaly_threshold", m.AnomalyThreshold),
)
}

// isIPBlacklisted checks if the given IP address is in the blacklist.
func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {
ipStr := extractIP(remoteAddr)
if ipStr == "" {
return false
}

// Check if the IP is directly blacklisted
if m.ipBlacklist[ipStr] {
return true
}

// Check if the IP falls within any CIDR range in the blacklist
ip := net.ParseIP(ipStr)
if ip == nil {
return false
}

for blacklistEntry := range m.ipBlacklist {
if strings.Contains(blacklistEntry, "/") {
_, ipNet, err := net.ParseCIDR(blacklistEntry)
if err != nil {
continue
}
if ipNet.Contains(ip) {
return true
}
}
}

return false
}

0 comments on commit 8bcb6b7

Please sign in to comment.