From 77d5da916740c38d2b03a55a64c9c3cd9ea7a08f Mon Sep 17 00:00:00 2001 From: fabriziosalmi Date: Wed, 15 Jan 2025 01:16:47 +0100 Subject: [PATCH] modularizations --- Caddyfile | 8 ++- caddywaf.go | 169 +++---------------------------------------------- config.go | 143 +++++++++++++++++------------------------ helpers.go | 50 +++++++++++++++ ratelimiter.go | 124 ++++++++++++++++++++++++++++++++++++ 5 files changed, 248 insertions(+), 246 deletions(-) create mode 100644 helpers.go create mode 100644 ratelimiter.go diff --git a/Caddyfile b/Caddyfile index aea16ed..f170853 100644 --- a/Caddyfile +++ b/Caddyfile @@ -21,7 +21,11 @@ anomaly_threshold 10 block_countries GeoLite2-Country.mmdb RU CN KP # whitelist_countries GeoLite2-Country.mmdb US - rate_limit 10000 1m 5m + rate_limit { + requests 10 + window 10s + cleanup_interval 5m + } rule_file rules.json # rule_file rules/wordpress.json ip_blacklist_file ip_blacklist.txt @@ -42,4 +46,4 @@ respond "Hello world!" 200 } } -} +} \ No newline at end of file diff --git a/caddywaf.go b/caddywaf.go index a6286e1..77cd762 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "log" "net" "net/http" "os" @@ -52,7 +51,7 @@ func (m *Middleware) logVersion() { func init() { // Register the module and directive without logging - caddy.RegisterModule(Middleware{}) + caddy.RegisterModule(&Middleware{}) httpcaddyfile.RegisterHandlerDirective("waf", parseCaddyfile) } @@ -62,114 +61,6 @@ var ( _ caddyfile.Unmarshaler = (*Middleware)(nil) ) -// requestCounter struct -type requestCounter struct { - count int - window time.Time -} - -// RateLimit struct -type RateLimit struct { - Requests int `json:"requests"` - Window time.Duration `json:"window"` - CleanupInterval time.Duration `json:"cleanup_interval"` - Paths []string `json:"paths,omitempty"` // New: optional paths to apply rate limit - PathRegexes []*regexp.Regexp `json:"-"` // New: compiled regexes for the given paths - MatchAllPaths bool `json:"match_all_paths,omitempty"` -} - -// RateLimiter struct -type RateLimiter struct { - sync.RWMutex - requests map[string]*requestCounter - config RateLimit - stopCleanup chan struct{} // Channel to signal cleanup goroutine to stop -} - -// isRateLimited checks if a given IP is rate limited. -func (rl *RateLimiter) isRateLimited(ip string) bool { - now := time.Now() - - rl.Lock() - defer rl.Unlock() - - counter, exists := rl.requests[ip] - if exists { - if now.Sub(counter.window) > rl.config.Window { - // Window expired, reset the counter - rl.requests[ip] = &requestCounter{count: 1, window: now} - return false - } - - // Window not expired, increment the counter - counter.count++ - return counter.count > rl.config.Requests - } - - // IP doesn't exist, add it - rl.requests[ip] = &requestCounter{count: 1, window: now} - return false -} - -// cleanupExpiredEntries removes expired entries from the rate limiter. -func (rl *RateLimiter) cleanupExpiredEntries() { - now := time.Now() - var expiredIPs []string - - // Collect expired IPs to delete (read lock) - rl.RLock() - for ip, counter := range rl.requests { - if now.Sub(counter.window) > rl.config.Window { - expiredIPs = append(expiredIPs, ip) - } - } - rl.RUnlock() - - // Delete expired IPs (write lock) - if len(expiredIPs) > 0 { - rl.Lock() - for _, ip := range expiredIPs { - delete(rl.requests, ip) - } - rl.Unlock() - } -} - -// startCleanup starts the goroutine to periodically clean up expired entries. -func (rl *RateLimiter) startCleanup() { - // Ensure stopCleanup channel is created only once - if rl.stopCleanup == nil { - rl.stopCleanup = make(chan struct{}) - } - - go func() { - log.Println("[INFO] Starting rate limiter cleanup goroutine") // Added logging - ticker := time.NewTicker(rl.config.CleanupInterval) // Use the specified cleanup interval - defer func() { - ticker.Stop() - log.Println("[INFO] Rate limiter cleanup goroutine stopped") // Added logging on exit - }() - for { - select { - case <-ticker.C: - rl.cleanupExpiredEntries() - case <-rl.stopCleanup: - return - } - } - }() -} - -// signalStopCleanup signals the cleanup goroutine to stop. -func (rl *RateLimiter) signalStopCleanup() { - if rl.stopCleanup != nil { - log.Println("[INFO] Signaling rate limiter cleanup goroutine to stop") // Added logging - close(rl.stopCleanup) - // We avoid setting rl.stopCleanup to nil here for maximum safety. - // Subsequent calls to signalStopCleanup will still be protected by the nil check. - } -} - // CountryAccessFilter struct type CountryAccessFilter struct { Enabled bool `json:"enabled"` @@ -213,13 +104,11 @@ type Middleware struct { IPBlacklistFile string `json:"ip_blacklist_file"` DNSBlacklistFile string `json:"dns_blacklist_file"` AnomalyThreshold int `json:"anomaly_threshold"` - RateLimit RateLimit `json:"rate_limit"` CountryBlock CountryAccessFilter `json:"country_block"` CountryWhitelist CountryAccessFilter `json:"country_whitelist"` Rules map[int][]Rule `json:"-"` ipBlacklist map[string]bool `json:"-"` // Changed type here dnsBlacklist map[string]bool `json:"-"` - rateLimiter *RateLimiter `json:"-"` logger *zap.Logger LogSeverity string `json:"log_severity,omitempty"` LogJSON bool `json:"log_json,omitempty"` @@ -248,6 +137,9 @@ type Middleware struct { blacklistLoader *BlacklistLoader `json:"-"` geoIPHandler *GeoIPHandler `json:"-"` requestValueExtractor *RequestValueExtractor `json:"-"` + + RateLimit RateLimit `json:"rate_limit,omitempty"` + rateLimiter *RateLimiter `json:"-"` } // WAFState struct: Used to maintain state between phases @@ -258,7 +150,7 @@ type WAFState struct { ResponseWritten bool } -func (Middleware) CaddyModule() caddy.ModuleInfo { +func (*Middleware) CaddyModule() caddy.ModuleInfo { return caddy.ModuleInfo{ ID: "http.handlers.waf", New: func() caddy.Module { return &Middleware{} }, @@ -278,7 +170,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error) err := m.UnmarshalCaddyfile(h.Dispenser) if err != nil { // Improve error message by including file and line number - return nil, fmt.Errorf("Caddyfile parse error: %w", err) + return nil, fmt.Errorf("caddyfile parse error: %w", err) } logger.Info("Successfully parsed Caddyfile", zap.String("file", h.Dispenser.File())) @@ -539,35 +431,6 @@ func (m *Middleware) blockRequest(w http.ResponseWriter, r *http.Request, state } } -// extractIP extracts the IP address from a remote address string. -func extractIP(remoteAddr string) string { - if remoteAddr == "" { - return "" - } - - // Remove brackets from IPv6 addresses - if strings.HasPrefix(remoteAddr, "[") && strings.HasSuffix(remoteAddr, "]") { - remoteAddr = strings.TrimPrefix(remoteAddr, "[") - remoteAddr = strings.TrimSuffix(remoteAddr, "]") - } - - host, _, err := net.SplitHostPort(remoteAddr) - if err == nil { - return host - } - - ip := net.ParseIP(remoteAddr) - if ip != nil { - return ip.String() - } - - return "" -} - -func extractPath(r *http.Request) string { - return r.URL.Path -} - type responseRecorder struct { http.ResponseWriter body *bytes.Buffer @@ -1273,15 +1136,15 @@ func (m *Middleware) Provision(ctx caddy.Context) error { // Rate Limiter Setup if m.RateLimit.Requests > 0 { + if m.RateLimit.Window <= 0 { + return fmt.Errorf("invalid rate limit configuration: requests and window must be greater than zero") + } m.logger.Info("Rate limit configuration", zap.Int("requests", m.RateLimit.Requests), zap.Duration("window", m.RateLimit.Window), zap.Duration("cleanup_interval", m.RateLimit.CleanupInterval), ) - m.rateLimiter = &RateLimiter{ - requests: make(map[string]*requestCounter), - config: m.RateLimit, - } + m.rateLimiter = NewRateLimiter(m.RateLimit) m.rateLimiter.startCleanup() } else { m.logger.Info("Rate limiting is disabled") @@ -1527,18 +1390,6 @@ func (m *Middleware) loadRules(paths []string, ipBlacklistPath string, dnsBlackl return nil } -// fileExists checks if a file exists and is readable. -func fileExists(path string) bool { - if path == "" { - return false - } - info, err := os.Stat(path) - if os.IsNotExist(err) { - return false - } - return !info.IsDir() -} - func (m *Middleware) isIPBlacklisted(remoteAddr string) bool { ipStr := extractIP(remoteAddr) if ipStr == "" { diff --git a/config.go b/config.go index 94d195f..1dbd035 100644 --- a/config.go +++ b/config.go @@ -3,7 +3,6 @@ package caddywaf import ( "fmt" "os" - "regexp" "strconv" "strings" "time" @@ -20,6 +19,7 @@ type ConfigLoader struct { func NewConfigLoader(logger *zap.Logger) *ConfigLoader { return &ConfigLoader{logger: logger} } + func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware) error { if cl.logger == nil { cl.logger = zap.NewNop() @@ -62,9 +62,60 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware zap.Int("line", d.Line()), ) case "rate_limit": - if err := cl.parseRateLimit(d, m); err != nil { - return err + if m.RateLimit.Requests > 0 { + return d.Err("rate_limit specified multiple times") + } + + rl := RateLimit{} + for nesting := d.Nesting(); d.NextBlock(nesting); { + switch d.Val() { + case "requests": + if !d.NextArg() { + return d.Err("requests requires an argument") + } + reqs, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid requests value: %v", err) + } + rl.Requests = reqs + case "window": + if !d.NextArg() { + return d.Err("window requires an argument") + } + window, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("invalid window value: %v", err) + } + rl.Window = window + case "cleanup_interval": + if !d.NextArg() { + return d.Err("cleanup_interval requires an argument") + } + interval, err := time.ParseDuration(d.Val()) + if err != nil { + return d.Errf("invalid cleanup_interval value: %v", err) + } + rl.CleanupInterval = interval + case "paths": + rl.Paths = d.RemainingArgs() + case "match_all_paths": + rl.MatchAllPaths = true + default: + return d.Errf("invalid rate_limit option: %s", d.Val()) + } + } + if rl.Requests <= 0 || rl.Window <= 0 { + return d.Err("requests and window must be greater than zero") + } + // Create temporary RateLimit struct with basic fields only + m.RateLimit = RateLimit{ + Requests: rl.Requests, + Window: rl.Window, + CleanupInterval: rl.CleanupInterval, + Paths: rl.Paths, + MatchAllPaths: rl.MatchAllPaths, } + cl.logger.Debug("Rate limit parsed", zap.Any("rate_limit", m.RateLimit)) case "block_countries": if err := cl.parseCountryBlock(d, m, true); err != nil { return err @@ -109,7 +160,10 @@ func (cl *ConfigLoader) UnmarshalCaddyfile(d *caddyfile.Dispenser, m *Middleware } } } - return cl.validateConfig(m) + if len(m.RuleFiles) == 0 { + return fmt.Errorf("no rule files specified") + } + return nil } func (cl *ConfigLoader) parseRuleFile(d *caddyfile.Dispenser, m *Middleware) error { if !d.NextArg() { @@ -197,74 +251,6 @@ func (cl *ConfigLoader) parseCustomResponse(d *caddyfile.Dispenser, m *Middlewar } return nil } -func (cl *ConfigLoader) parseRateLimit(d *caddyfile.Dispenser, m *Middleware) error { - if !d.NextArg() { - return fmt.Errorf("File: %s, Line: %d: missing requests value for rate_limit", d.File(), d.Line()) - } - requests, err := strconv.Atoi(d.Val()) - if err != nil { - return fmt.Errorf("File: %s, Line: %d: invalid requests value for rate_limit: %v", d.File(), d.Line(), err) - } - - if !d.NextArg() { - return fmt.Errorf("File: %s, Line: %d: missing window duration for rate_limit", d.File(), d.Line()) - } - window, err := time.ParseDuration(d.Val()) - if err != nil { - return fmt.Errorf("File: %s, Line: %d: invalid duration for rate_limit: %v", d.File(), d.Line(), err) - } - - cleanupInterval := time.Minute - if d.NextArg() { - cleanupInterval, err = time.ParseDuration(d.Val()) - if err != nil { - return fmt.Errorf("File: %s, Line: %d: invalid cleanup interval: %v", d.File(), d.Line(), err) - } - } - - var paths []string - matchAllPaths := false - for d.NextArg() { - arg := d.Val() - if arg == "match_all_paths" { - matchAllPaths = true - cl.logger.Debug("Rate limiter match_all_paths enabled", zap.String("file", d.File()), zap.Int("line", d.Line())) - continue - } - paths = append(paths, arg) - } - - // Compile path regexes for all given paths - var pathRegexes []*regexp.Regexp - for _, path := range paths { - compiledRegex, err := regexp.Compile(path) - if err != nil { - return fmt.Errorf("File: %s, Line: %d: invalid regex in rate limit paths: %v", d.File(), d.Line(), err) - } - pathRegexes = append(pathRegexes, compiledRegex) - - } - - m.RateLimit = RateLimit{ - Requests: requests, - Window: window, - CleanupInterval: cleanupInterval, - Paths: paths, - PathRegexes: pathRegexes, - MatchAllPaths: matchAllPaths, - } - - cl.logger.Debug("Rate limit configured", - zap.Int("requests", requests), - zap.Duration("window", window), - zap.Duration("cleanup_interval", cleanupInterval), - zap.Strings("paths", paths), - zap.Bool("match_all_paths", matchAllPaths), - zap.String("file", d.File()), - zap.Int("line", d.Line()), - ) - return nil -} func (cl *ConfigLoader) parseCountryBlock(d *caddyfile.Dispenser, m *Middleware, isBlock bool) error { target := &m.CountryBlock @@ -330,16 +316,3 @@ func (cl *ConfigLoader) parseAnomalyThreshold(d *caddyfile.Dispenser, m *Middlew cl.logger.Debug("Anomaly threshold set", zap.Int("threshold", threshold)) return nil } - -func (cl *ConfigLoader) validateConfig(m *Middleware) error { - if m.RateLimit.Requests <= 0 || m.RateLimit.Window <= 0 { - return fmt.Errorf("invalid rate limit configuration: requests and window must be greater than zero") - } - if m.CountryBlock.Enabled && m.CountryBlock.GeoIPDBPath == "" { - return fmt.Errorf("country block is enabled but no GeoIP database path specified") - } - if len(m.RuleFiles) == 0 { - return fmt.Errorf("no rule files specified") - } - return nil -} diff --git a/helpers.go b/helpers.go new file mode 100644 index 0000000..d87ffc9 --- /dev/null +++ b/helpers.go @@ -0,0 +1,50 @@ +package caddywaf + +import ( + "net" + "net/http" + "os" + "strings" +) + +// extractIP extracts the IP address from a remote address string. +func extractIP(remoteAddr string) string { + if remoteAddr == "" { + return "" + } + + // Remove brackets from IPv6 addresses + if strings.HasPrefix(remoteAddr, "[") && strings.HasSuffix(remoteAddr, "]") { + remoteAddr = strings.TrimPrefix(remoteAddr, "[") + remoteAddr = strings.TrimSuffix(remoteAddr, "]") + } + + host, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + return host + } + + ip := net.ParseIP(remoteAddr) + if ip != nil { + return ip.String() + } + + return "" +} + +// extractPath extracts the path from a http.Request +func extractPath(r *http.Request) string { + return r.URL.Path +} + +// fileExists checks if a file exists and is readable. +func fileExists(path string) bool { + if path == "" { + return false + } + info, err := os.Stat(path) + if os.IsNotExist(err) { + return false + } + return !info.IsDir() +} diff --git a/ratelimiter.go b/ratelimiter.go new file mode 100644 index 0000000..7357faa --- /dev/null +++ b/ratelimiter.go @@ -0,0 +1,124 @@ +package caddywaf + +import ( + "log" + "regexp" + "sync" + "time" +) + +// requestCounter struct +type requestCounter struct { + count int + window time.Time +} + +// RateLimit struct +type RateLimit struct { + Requests int `json:"requests"` + Window time.Duration `json:"window"` + CleanupInterval time.Duration `json:"cleanup_interval"` + Paths []string `json:"paths,omitempty"` // New: optional paths to apply rate limit + PathRegexes []*regexp.Regexp `json:"-"` // New: compiled regexes for the given paths + MatchAllPaths bool `json:"match_all_paths,omitempty"` +} + +// RateLimiter struct +type RateLimiter struct { + sync.RWMutex + requests map[string]*requestCounter + config RateLimit + stopCleanup chan struct{} // Channel to signal cleanup goroutine to stop +} + +// NewRateLimiter creates a new RateLimiter +func NewRateLimiter(config RateLimit) *RateLimiter { + return &RateLimiter{ + requests: make(map[string]*requestCounter), + config: config, + } +} + +// isRateLimited checks if a given IP is rate limited. +func (rl *RateLimiter) isRateLimited(ip string) bool { + now := time.Now() + + rl.Lock() + defer rl.Unlock() + + counter, exists := rl.requests[ip] + if exists { + if now.Sub(counter.window) > rl.config.Window { + // Window expired, reset the counter + rl.requests[ip] = &requestCounter{count: 1, window: now} + return false + } + + // Window not expired, increment the counter + counter.count++ + return counter.count > rl.config.Requests + } + + // IP doesn't exist, add it + rl.requests[ip] = &requestCounter{count: 1, window: now} + return false +} + +// cleanupExpiredEntries removes expired entries from the rate limiter. +func (rl *RateLimiter) cleanupExpiredEntries() { + now := time.Now() + var expiredIPs []string + + // Collect expired IPs to delete (read lock) + rl.RLock() + for ip, counter := range rl.requests { + if now.Sub(counter.window) > rl.config.Window { + expiredIPs = append(expiredIPs, ip) + } + } + rl.RUnlock() + + // Delete expired IPs (write lock) + if len(expiredIPs) > 0 { + rl.Lock() + for _, ip := range expiredIPs { + delete(rl.requests, ip) + } + rl.Unlock() + } +} + +// startCleanup starts the goroutine to periodically clean up expired entries. +func (rl *RateLimiter) startCleanup() { + // Ensure stopCleanup channel is created only once + if rl.stopCleanup == nil { + rl.stopCleanup = make(chan struct{}) + } + + go func() { + log.Println("[INFO] Starting rate limiter cleanup goroutine") // Added logging + ticker := time.NewTicker(rl.config.CleanupInterval) // Use the specified cleanup interval + defer func() { + ticker.Stop() + log.Println("[INFO] Rate limiter cleanup goroutine stopped") // Added logging on exit + }() + for { + select { + case <-ticker.C: + rl.cleanupExpiredEntries() + case <-rl.stopCleanup: + return + } + } + }() +} + +// signalStopCleanup signals the cleanup goroutine to stop. +func (rl *RateLimiter) signalStopCleanup() { + if rl.stopCleanup != nil { + log.Println("[INFO] Signaling rate limiter cleanup goroutine to stop") // Added logging + close(rl.stopCleanup) + // We avoid setting rl.stopCleanup to nil here for maximum safety. + // Subsequent calls to signalStopCleanup will still be protected by the nil check. + } +}