Skip to content

Commit

Permalink
modularizations
Browse files Browse the repository at this point in the history
  • Loading branch information
fabriziosalmi committed Jan 15, 2025
1 parent 417696c commit 77d5da9
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 246 deletions.
8 changes: 6 additions & 2 deletions Caddyfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
anomaly_threshold 10
block_countries GeoLite2-Country.mmdb RU CN KP
# whitelist_countries GeoLite2-Country.mmdb US
rate_limit 10000 1m 5m
rate_limit {
requests 10
window 10s
cleanup_interval 5m
}
rule_file rules.json
# rule_file rules/wordpress.json
ip_blacklist_file ip_blacklist.txt
Expand All @@ -42,4 +46,4 @@
respond "Hello world!" 200
}
}
}
}
169 changes: 10 additions & 159 deletions caddywaf.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"net/http"
"os"
Expand Down Expand Up @@ -52,7 +51,7 @@ func (m *Middleware) logVersion() {

func init() {
// Register the module and directive without logging
caddy.RegisterModule(Middleware{})
caddy.RegisterModule(&Middleware{})
httpcaddyfile.RegisterHandlerDirective("waf", parseCaddyfile)
}

Expand All @@ -62,114 +61,6 @@ var (
_ caddyfile.Unmarshaler = (*Middleware)(nil)
)

// requestCounter struct
type requestCounter struct {
count int
window time.Time
}

// RateLimit struct
type RateLimit struct {
Requests int `json:"requests"`
Window time.Duration `json:"window"`
CleanupInterval time.Duration `json:"cleanup_interval"`
Paths []string `json:"paths,omitempty"` // New: optional paths to apply rate limit
PathRegexes []*regexp.Regexp `json:"-"` // New: compiled regexes for the given paths
MatchAllPaths bool `json:"match_all_paths,omitempty"`
}

// RateLimiter struct
type RateLimiter struct {
sync.RWMutex
requests map[string]*requestCounter
config RateLimit
stopCleanup chan struct{} // Channel to signal cleanup goroutine to stop
}

// isRateLimited checks if a given IP is rate limited.
func (rl *RateLimiter) isRateLimited(ip string) bool {
now := time.Now()

rl.Lock()
defer rl.Unlock()

counter, exists := rl.requests[ip]
if exists {
if now.Sub(counter.window) > rl.config.Window {
// Window expired, reset the counter
rl.requests[ip] = &requestCounter{count: 1, window: now}
return false
}

// Window not expired, increment the counter
counter.count++
return counter.count > rl.config.Requests
}

// IP doesn't exist, add it
rl.requests[ip] = &requestCounter{count: 1, window: now}
return false
}

// cleanupExpiredEntries removes expired entries from the rate limiter.
func (rl *RateLimiter) cleanupExpiredEntries() {
now := time.Now()
var expiredIPs []string

// Collect expired IPs to delete (read lock)
rl.RLock()
for ip, counter := range rl.requests {
if now.Sub(counter.window) > rl.config.Window {
expiredIPs = append(expiredIPs, ip)
}
}
rl.RUnlock()

// Delete expired IPs (write lock)
if len(expiredIPs) > 0 {
rl.Lock()
for _, ip := range expiredIPs {
delete(rl.requests, ip)
}
rl.Unlock()
}
}

// startCleanup starts the goroutine to periodically clean up expired entries.
func (rl *RateLimiter) startCleanup() {
// Ensure stopCleanup channel is created only once
if rl.stopCleanup == nil {
rl.stopCleanup = make(chan struct{})
}

go func() {
log.Println("[INFO] Starting rate limiter cleanup goroutine") // Added logging
ticker := time.NewTicker(rl.config.CleanupInterval) // Use the specified cleanup interval
defer func() {
ticker.Stop()
log.Println("[INFO] Rate limiter cleanup goroutine stopped") // Added logging on exit
}()
for {
select {
case <-ticker.C:
rl.cleanupExpiredEntries()
case <-rl.stopCleanup:
return
}
}
}()
}

// signalStopCleanup signals the cleanup goroutine to stop.
func (rl *RateLimiter) signalStopCleanup() {
if rl.stopCleanup != nil {
log.Println("[INFO] Signaling rate limiter cleanup goroutine to stop") // Added logging
close(rl.stopCleanup)
// We avoid setting rl.stopCleanup to nil here for maximum safety.
// Subsequent calls to signalStopCleanup will still be protected by the nil check.
}
}

// CountryAccessFilter struct
type CountryAccessFilter struct {
Enabled bool `json:"enabled"`
Expand Down Expand Up @@ -213,13 +104,11 @@ type Middleware struct {
IPBlacklistFile string `json:"ip_blacklist_file"`
DNSBlacklistFile string `json:"dns_blacklist_file"`
AnomalyThreshold int `json:"anomaly_threshold"`
RateLimit RateLimit `json:"rate_limit"`
CountryBlock CountryAccessFilter `json:"country_block"`
CountryWhitelist CountryAccessFilter `json:"country_whitelist"`
Rules map[int][]Rule `json:"-"`
ipBlacklist map[string]bool `json:"-"` // Changed type here
dnsBlacklist map[string]bool `json:"-"`
rateLimiter *RateLimiter `json:"-"`
logger *zap.Logger
LogSeverity string `json:"log_severity,omitempty"`
LogJSON bool `json:"log_json,omitempty"`
Expand Down Expand Up @@ -248,6 +137,9 @@ type Middleware struct {
blacklistLoader *BlacklistLoader `json:"-"`
geoIPHandler *GeoIPHandler `json:"-"`
requestValueExtractor *RequestValueExtractor `json:"-"`

RateLimit RateLimit `json:"rate_limit,omitempty"`
rateLimiter *RateLimiter `json:"-"`
}

// WAFState struct: Used to maintain state between phases
Expand All @@ -258,7 +150,7 @@ type WAFState struct {
ResponseWritten bool
}

func (Middleware) CaddyModule() caddy.ModuleInfo {
func (*Middleware) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
ID: "http.handlers.waf",
New: func() caddy.Module { return &Middleware{} },
Expand All @@ -278,7 +170,7 @@ func parseCaddyfile(h httpcaddyfile.Helper) (caddyhttp.MiddlewareHandler, error)
err := m.UnmarshalCaddyfile(h.Dispenser)
if err != nil {
// Improve error message by including file and line number
return nil, fmt.Errorf("Caddyfile parse error: %w", err)
return nil, fmt.Errorf("caddyfile parse error: %w", err)
}

logger.Info("Successfully parsed Caddyfile", zap.String("file", h.Dispenser.File()))
Expand Down Expand Up @@ -539,35 +431,6 @@ func (m *Middleware) blockRequest(w http.ResponseWriter, r *http.Request, state
}
}

// extractIP extracts the IP address from a remote address string.
func extractIP(remoteAddr string) string {
if remoteAddr == "" {
return ""
}

// Remove brackets from IPv6 addresses
if strings.HasPrefix(remoteAddr, "[") && strings.HasSuffix(remoteAddr, "]") {
remoteAddr = strings.TrimPrefix(remoteAddr, "[")
remoteAddr = strings.TrimSuffix(remoteAddr, "]")
}

host, _, err := net.SplitHostPort(remoteAddr)
if err == nil {
return host
}

ip := net.ParseIP(remoteAddr)
if ip != nil {
return ip.String()
}

return ""
}

func extractPath(r *http.Request) string {
return r.URL.Path
}

type responseRecorder struct {
http.ResponseWriter
body *bytes.Buffer
Expand Down Expand Up @@ -1273,15 +1136,15 @@ func (m *Middleware) Provision(ctx caddy.Context) error {

// Rate Limiter Setup
if m.RateLimit.Requests > 0 {
if m.RateLimit.Window <= 0 {
return fmt.Errorf("invalid rate limit configuration: requests and window must be greater than zero")
}
m.logger.Info("Rate limit configuration",
zap.Int("requests", m.RateLimit.Requests),
zap.Duration("window", m.RateLimit.Window),
zap.Duration("cleanup_interval", m.RateLimit.CleanupInterval),
)
m.rateLimiter = &RateLimiter{
requests: make(map[string]*requestCounter),
config: m.RateLimit,
}
m.rateLimiter = NewRateLimiter(m.RateLimit)
m.rateLimiter.startCleanup()
} else {
m.logger.Info("Rate limiting is disabled")
Expand Down Expand Up @@ -1527,18 +1390,6 @@ func (m *Middleware) loadRules(paths []string, ipBlacklistPath string, dnsBlackl
return nil
}

// fileExists checks if a file exists and is readable.
func fileExists(path string) bool {
if path == "" {
return false
}
info, err := os.Stat(path)
if os.IsNotExist(err) {
return false
}
return !info.IsDir()
}

func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {
ipStr := extractIP(remoteAddr)
if ipStr == "" {
Expand Down
Loading

0 comments on commit 77d5da9

Please sign in to comment.