diff --git a/blacklist.go b/blacklist.go index c1ea48b..0e6cf47 100644 --- a/blacklist.go +++ b/blacklist.go @@ -10,38 +10,29 @@ import ( "go.uber.org/zap" ) -// BlacklistLoader struct +// BlacklistLoader handles loading IP and DNS blacklists from files. type BlacklistLoader struct { logger *zap.Logger } -// NewBlacklistLoader creates a new BlacklistLoader with a given logger +// NewBlacklistLoader creates a new BlacklistLoader with the provided logger. func NewBlacklistLoader(logger *zap.Logger) *BlacklistLoader { return &BlacklistLoader{logger: logger} } -// LoadIPBlacklistFromFile loads IP addresses from a file -func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]bool) error { +// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map. +func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error { if bl.logger == nil { bl.logger = zap.NewNop() } - // Initialize the IP blacklist - // Log the attempt to load the IP blacklist file - bl.logger.Debug("Loading IP blacklist from file", - zap.String("file", path), - ) + bl.logger.Debug("Loading IP blacklist from file", zap.String("file", path)) - // Attempt to read the file content, err := os.ReadFile(path) if err != nil { - bl.logger.Warn("Failed to read IP blacklist file", - zap.String("file", path), - zap.Error(err), - ) - return nil // Continue with an empty blacklist + bl.logger.Warn("Failed to read IP blacklist file", zap.String("file", path), zap.Error(err)) + return fmt.Errorf("failed to read IP blacklist file: %w", err) } - // Split the file content into lines lines := strings.Split(string(content), "\n") validEntries := 0 @@ -51,28 +42,22 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[ continue // Skip empty lines and comments } - // Check if the line is a valid IP or CIDR range if _, _, err := net.ParseCIDR(line); err == nil { - // It's a valid CIDR range - ipBlacklist[line] = true + // Valid CIDR range + ipBlacklist[line] = struct{}{} validEntries++ - bl.logger.Debug("Added CIDR range to blacklist", - zap.String("cidr", line), - ) + bl.logger.Debug("Added CIDR range to blacklist", zap.String("cidr", line)) continue } if ip := net.ParseIP(line); ip != nil { - // It's a valid IP address - ipBlacklist[line] = true + // Valid IP address + ipBlacklist[line] = struct{}{} validEntries++ - bl.logger.Debug("Added IP to blacklist", - zap.String("ip", line), - ) + bl.logger.Debug("Added IP to blacklist", zap.String("ip", line)) continue } - // Log invalid entries for debugging bl.logger.Warn("Invalid IP or CIDR range in blacklist file, skipping", zap.String("file", path), zap.Int("line", i+1), @@ -88,46 +73,36 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[ return nil } -// LoadDNSBlacklistFromFile loads DNS entries from a file -func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist map[string]bool) error { +// LoadDNSBlacklistFromFile loads DNS entries from a file into the provided map. +func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist map[string]struct{}) error { if bl.logger == nil { bl.logger = zap.NewNop() } - // Log the attempt to load the DNS blacklist file - bl.logger.Debug("Loading DNS blacklist from file", - zap.String("file", path), - ) + bl.logger.Debug("Loading DNS blacklist from file", zap.String("file", path)) - // Attempt to read the file content, err := os.ReadFile(path) if err != nil { - bl.logger.Warn("Failed to read DNS blacklist file", - zap.String("file", path), - zap.Error(err), - ) - return nil // Continue with an empty blacklist + bl.logger.Warn("Failed to read DNS blacklist file", zap.String("file", path), zap.Error(err)) + return fmt.Errorf("failed to read DNS blacklist file: %w", err) } - // Convert all entries to lowercase and trim whitespace and add to the map lines := strings.Split(string(content), "\n") - validEntriesCount := 0 + validEntries := 0 for _, line := range lines { line = strings.ToLower(strings.TrimSpace(line)) if line == "" || strings.HasPrefix(line, "#") { continue // Skip empty lines and comments } - dnsBlacklist[line] = true - validEntriesCount++ + dnsBlacklist[line] = struct{}{} + validEntries++ } - // Log the successful loading of the DNS blacklist bl.logger.Info("DNS blacklist loaded successfully", zap.String("file", path), - zap.Int("valid_entries", validEntriesCount), + zap.Int("valid_entries", validEntries), zap.Int("total_lines", len(lines)), ) - return nil } @@ -138,8 +113,11 @@ func (m *Middleware) isIPBlacklisted(remoteAddr string) bool { return false } + m.mu.RLock() + defer m.mu.RUnlock() + // Check if the IP is directly blacklisted - if m.ipBlacklist[ipStr] { + if _, exists := m.ipBlacklist[ipStr]; exists { return true } @@ -164,6 +142,7 @@ func (m *Middleware) isIPBlacklisted(remoteAddr string) bool { return false } +// isCountryInList checks if the IP's country is in the provided list using the GeoIP database. func (m *Middleware) isCountryInList(remoteAddr string, countryList []string, geoIP *maxminddb.Reader) (bool, error) { if m.geoIPHandler == nil { return false, fmt.Errorf("geoip handler not initialized") @@ -171,6 +150,7 @@ func (m *Middleware) isCountryInList(remoteAddr string, countryList []string, ge return m.geoIPHandler.IsCountryInList(remoteAddr, countryList, geoIP) } +// isDNSBlacklisted checks if the given host is in the DNS blacklist. func (m *Middleware) isDNSBlacklisted(host string) bool { normalizedHost := strings.ToLower(strings.TrimSpace(host)) if normalizedHost == "" { @@ -189,8 +169,15 @@ func (m *Middleware) isDNSBlacklisted(host string) bool { return true } - m.logger.Debug("Host is not blacklisted", - zap.String("host", host), - ) + m.logger.Debug("Host is not blacklisted", zap.String("host", host)) return false } + +// extractIP extracts the IP address from a remote address string. +func extractIP(remoteAddr string) string { + host, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return remoteAddr // Assume the input is already an IP address + } + return host +} diff --git a/caddywaf.go b/caddywaf.go index 4db5ff4..9854d1b 100644 --- a/caddywaf.go +++ b/caddywaf.go @@ -1,11 +1,9 @@ package caddywaf import ( - "bytes" "context" "encoding/json" "fmt" - "net" "net/http" "os" "regexp" @@ -17,7 +15,6 @@ import ( "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" "github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile" "github.com/caddyserver/caddy/v2/modules/caddyhttp" - "github.com/google/uuid" "github.com/oschwald/maxminddb-golang" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -83,8 +80,8 @@ type Middleware struct { CountryBlock CountryAccessFilter `json:"country_block"` CountryWhitelist CountryAccessFilter `json:"country_whitelist"` Rules map[int][]Rule `json:"-"` - ipBlacklist map[string]bool `json:"-"` - dnsBlacklist map[string]bool `json:"-"` + ipBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{} + dnsBlacklist map[string]struct{} `json:"-"` // Changed to map[string]struct{} logger *zap.Logger LogSeverity string `json:"log_severity,omitempty"` LogJSON bool `json:"log_json,omitempty"` @@ -266,14 +263,14 @@ func (m *Middleware) Provision(ctx caddy.Context) error { return fmt.Errorf("failed to load config: %w", err) } - m.ipBlacklist = make(map[string]bool) + m.ipBlacklist = make(map[string]struct{}) // Changed to map[string]struct{} if m.IPBlacklistFile != "" { err = m.blacklistLoader.LoadIPBlacklistFromFile(m.IPBlacklistFile, m.ipBlacklist) if err != nil { return fmt.Errorf("failed to load IP blacklist: %w", err) } } - m.dnsBlacklist = make(map[string]bool) + m.dnsBlacklist = make(map[string]struct{}) // Changed to map[string]struct{} if m.DNSBlacklistFile != "" { err = m.blacklistLoader.LoadDNSBlacklistFromFile(m.DNSBlacklistFile, m.dnsBlacklist) if err != nil { @@ -345,263 +342,6 @@ func (m *Middleware) Shutdown(ctx context.Context) error { return firstError } -// ==================== Rule and Blacklist Management ==================== - -func (m *Middleware) loadRules(paths []string, ipBlacklistPath string, dnsBlacklistPath string) error { - m.mu.Lock() - defer m.mu.Unlock() - - m.logger.Debug("Loading rules and blacklists from files", zap.Strings("rule_files", paths), zap.String("ip_blacklist", ipBlacklistPath), zap.String("dns_blacklist", dnsBlacklistPath)) - - m.Rules = make(map[int][]Rule) - totalRules := 0 - var invalidFiles []string - var allInvalidRules []string - ruleIDs := make(map[string]bool) - - for _, path := range paths { - content, err := os.ReadFile(path) - if err != nil { - m.logger.Error("Failed to read rule file", zap.String("file", path), zap.Error(err)) - invalidFiles = append(invalidFiles, path) - continue - } - - var rules []Rule - if err := json.Unmarshal(content, &rules); err != nil { - m.logger.Error("Failed to unmarshal rules from file", zap.String("file", path), zap.Error(err)) - invalidFiles = append(invalidFiles, path) - continue - } - - var invalidRulesInFile []string - for i, rule := range rules { - if err := validateRule(&rule); err != nil { - invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Rule at index %d: %v", i, err)) - continue - } - - if _, exists := ruleIDs[rule.ID]; exists { - invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Duplicate rule ID '%s' at index %d", rule.ID, i)) - continue - } - ruleIDs[rule.ID] = true - - regex, err := regexp.Compile(rule.Pattern) - if err != nil { - m.logger.Error("Failed to compile regex for rule", zap.String("rule_id", rule.ID), zap.String("pattern", rule.Pattern), zap.Error(err)) - invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Rule '%s': invalid regex pattern: %v", rule.ID, err)) - continue - } - rule.regex = regex - - if _, ok := m.Rules[rule.Phase]; !ok { - m.Rules[rule.Phase] = []Rule{} - } - - m.Rules[rule.Phase] = append(m.Rules[rule.Phase], rule) - totalRules++ - } - if len(invalidRulesInFile) > 0 { - m.logger.Warn("Some rules failed validation", zap.String("file", path), zap.Strings("invalid_rules", invalidRulesInFile)) - allInvalidRules = append(allInvalidRules, invalidRulesInFile...) - } - - m.logger.Info("Rules loaded", zap.String("file", path), zap.Int("total_rules", len(rules)), zap.Int("invalid_rules", len(invalidRulesInFile))) - } - - m.ipBlacklist = make(map[string]bool) - if ipBlacklistPath != "" { - content, err := os.ReadFile(ipBlacklistPath) - if err != nil { - m.logger.Warn("Failed to read IP blacklist file", zap.String("file", ipBlacklistPath), zap.Error(err)) - } else { - lines := strings.Split(string(content), "\n") - validEntries := 0 - for i, line := range lines { - line = strings.TrimSpace(line) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - if _, _, err := net.ParseCIDR(line); err == nil { - m.ipBlacklist[line] = true - validEntries++ - m.logger.Debug("Added CIDR range to blacklist", zap.String("cidr", line)) - continue - } - - if ip := net.ParseIP(line); ip != nil { - m.ipBlacklist[line] = true - validEntries++ - m.logger.Debug("Added IP to blacklist", zap.String("ip", line)) - continue - } - - m.logger.Warn("Invalid IP or CIDR range in blacklist file, skipping", - zap.String("file", ipBlacklistPath), - zap.Int("line", i+1), - zap.String("entry", line), - ) - } - m.logger.Info("IP blacklist loaded successfully", - zap.String("file", ipBlacklistPath), - zap.Int("valid_entries", validEntries), - zap.Int("total_lines", len(lines)), - ) - } - } - - m.dnsBlacklist = make(map[string]bool) - if dnsBlacklistPath != "" { - content, err := os.ReadFile(dnsBlacklistPath) - if err != nil { - m.logger.Warn("Failed to read DNS blacklist file", zap.String("file", dnsBlacklistPath), zap.Error(err)) - } else { - lines := strings.Split(string(content), "\n") - validEntriesCount := 0 - for _, line := range lines { - line = strings.ToLower(strings.TrimSpace(line)) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - m.dnsBlacklist[line] = true - validEntriesCount++ - } - m.logger.Info("DNS blacklist loaded successfully", - zap.String("file", dnsBlacklistPath), - zap.Int("valid_entries", validEntriesCount), - zap.Int("total_lines", len(lines)), - ) - } - } - - if len(invalidFiles) > 0 { - m.logger.Warn("Some rule files could not be loaded", zap.Strings("invalid_files", invalidFiles)) - } - if len(allInvalidRules) > 0 { - m.logger.Warn("Some rules across files failed validation", zap.Strings("invalid_rules", allInvalidRules)) - } - - if totalRules == 0 && len(invalidFiles) > 0 { - return fmt.Errorf("no valid rules were loaded from any file") - } - m.logger.Debug("Rules and Blacklists loaded successfully", zap.Int("total_rules", totalRules)) - - return nil -} - -// ==================== Request Handling ==================== - -func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { - logID := uuid.New().String() - ctx := context.WithValue(r.Context(), "logID", logID) - r = r.WithContext(ctx) - - m.logRequest(zapcore.DebugLevel, "Entering ServeHTTP", zap.String("path", r.URL.Path)) - - state := &WAFState{ - TotalScore: 0, - Blocked: false, - StatusCode: http.StatusOK, - ResponseWritten: false, - } - - m.logger.Info("WAF evaluation started", - zap.String("log_id", logID), - zap.String("method", r.Method), - zap.String("path", r.URL.Path), - zap.String("source_ip", r.RemoteAddr), - zap.String("user_agent", r.UserAgent()), - zap.String("query_params", r.URL.RawQuery), - ) - - block := func(statusCode int, reason string, fields ...zap.Field) error { - if !state.ResponseWritten { - m.blockRequest(w, r, state, statusCode, append(fields, zap.String("reason", reason))...) - return nil - } - m.logger.Debug("Blocking action skipped, response already written", zap.String("log_id", logID), zap.String("reason", reason)) - return nil - } - - if m.CountryWhitelist.Enabled { - whitelisted, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP) - if err != nil { - m.logRequest(zapcore.ErrorLevel, "Failed to check whitelist", zap.String("log_id", logID), zap.Error(err)) - } else if whitelisted { - m.logRequest(zapcore.InfoLevel, "Request allowed - country whitelisted", zap.String("log_id", logID)) - return next.ServeHTTP(w, r) - } - } - - if m.CountryBlock.Enabled { - blacklisted, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP) - if err != nil { - m.logRequest(zapcore.ErrorLevel, "Failed to check blacklist", zap.String("log_id", logID), zap.Error(err)) - return block(http.StatusInternalServerError, "blacklist_check_error") - } else if blacklisted { - m.logRequest(zapcore.WarnLevel, "Request blocked - country blacklisted", zap.String("log_id", logID)) - return block(http.StatusForbidden, "country_blacklist") - } - } - - m.handlePhase(w, r, 1, state) - if state.Blocked { - w.WriteHeader(state.StatusCode) - return nil - } - - m.handlePhase(w, r, 2, state) - if state.Blocked { - w.WriteHeader(state.StatusCode) - return nil - } - - recorder := &responseRecorder{ResponseWriter: w, body: new(bytes.Buffer)} - err := next.ServeHTTP(recorder, r) - - m.handlePhase(recorder, r, 3, state) - if state.Blocked { - recorder.WriteHeader(state.StatusCode) - return nil - } - - if recorder.body != nil { - body := recorder.body.String() - m.logger.Debug("Response body captured", zap.String("body", body)) - - for _, rule := range m.Rules[4] { - if rule.regex.MatchString(body) { - m.processRuleMatch(recorder, r, &rule, body, state) - if state.Blocked { - recorder.WriteHeader(state.StatusCode) - return nil - } - } - } - - if !state.ResponseWritten { - _, writeErr := w.Write(recorder.body.Bytes()) - if writeErr != nil { - m.logger.Error("Failed to write response body", zap.Error(writeErr)) - } - } - } - - if m.MetricsEndpoint != "" && r.URL.Path == m.MetricsEndpoint { - return m.handleMetricsRequest(w, r) - } - - m.logger.Info("WAF evaluation complete", - zap.String("log_id", logID), - zap.Int("total_score", state.TotalScore), - zap.Bool("blocked", state.Blocked), - ) - - return err -} - // ==================== Helper Functions ==================== func (m *Middleware) logVersion() { @@ -728,7 +468,7 @@ func (m *Middleware) ReloadConfig() error { return fmt.Errorf("failed to reload rules: %v", err) } - newIPBlacklist := make(map[string]bool) + newIPBlacklist := make(map[string]struct{}) // Changed to map[string]struct{} if m.IPBlacklistFile != "" { if err := m.loadIPBlacklistIntoMap(m.IPBlacklistFile, newIPBlacklist); err != nil { m.logger.Error("Failed to reload IP blacklist", zap.String("file", m.IPBlacklistFile), zap.Error(err)) @@ -738,7 +478,7 @@ func (m *Middleware) ReloadConfig() error { m.logger.Debug("No IP blacklist file specified, skipping reload") } - newDNSBlacklist := make(map[string]bool) + newDNSBlacklist := make(map[string]struct{}) // Changed to map[string]struct{} if m.DNSBlacklistFile != "" { if err := m.loadDNSBlacklistIntoMap(m.DNSBlacklistFile, newDNSBlacklist); err != nil { m.logger.Error("Failed to reload DNS blacklist", zap.String("file", m.DNSBlacklistFile), zap.Error(err)) @@ -781,7 +521,7 @@ func (m *Middleware) loadRulesIntoMap(rulesMap map[int][]Rule) error { return nil } -func (m *Middleware) loadIPBlacklistIntoMap(path string, blacklistMap map[string]bool) error { +func (m *Middleware) loadIPBlacklistIntoMap(path string, blacklistMap map[string]struct{}) error { content, err := os.ReadFile(path) if err != nil { return fmt.Errorf("failed to read IP blacklist file: %v", err) @@ -793,12 +533,12 @@ func (m *Middleware) loadIPBlacklistIntoMap(path string, blacklistMap map[string if line == "" || strings.HasPrefix(line, "#") { continue } - blacklistMap[line] = true + blacklistMap[line] = struct{}{} // Changed to struct{}{} } return nil } -func (m *Middleware) loadDNSBlacklistIntoMap(path string, blacklistMap map[string]bool) error { +func (m *Middleware) loadDNSBlacklistIntoMap(path string, blacklistMap map[string]struct{}) error { content, err := os.ReadFile(path) if err != nil { return fmt.Errorf("failed to read DNS blacklist file: %v", err) @@ -810,7 +550,7 @@ func (m *Middleware) loadDNSBlacklistIntoMap(path string, blacklistMap map[strin if line == "" || strings.HasPrefix(line, "#") { continue } - blacklistMap[line] = true + blacklistMap[line] = struct{}{} // Changed to struct{}{} } return nil } diff --git a/helpers.go b/helpers.go index 406715d..f1604a4 100644 --- a/helpers.go +++ b/helpers.go @@ -1,36 +1,9 @@ package caddywaf import ( - "net" "os" - "strings" ) -// extractIP extracts the IP address from a remote address string. -func extractIP(remoteAddr string) string { - if remoteAddr == "" { - return "" - } - - // Remove brackets from IPv6 addresses - if strings.HasPrefix(remoteAddr, "[") && strings.HasSuffix(remoteAddr, "]") { - remoteAddr = strings.TrimPrefix(remoteAddr, "[") - remoteAddr = strings.TrimSuffix(remoteAddr, "]") - } - - host, _, err := net.SplitHostPort(remoteAddr) - if err == nil { - return host - } - - ip := net.ParseIP(remoteAddr) - if ip != nil { - return ip.String() - } - - return "" -} - // fileExists checks if a file exists and is readable. func fileExists(path string) bool { if path == "" { diff --git a/request.go b/request.go index 13ec912..299a07e 100644 --- a/request.go +++ b/request.go @@ -2,6 +2,7 @@ package caddywaf import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -9,7 +10,10 @@ import ( "strconv" "strings" + "github.com/caddyserver/caddy/v2/modules/caddyhttp" + "github.com/google/uuid" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) // RequestValueExtractor struct @@ -317,3 +321,112 @@ func (rve *RequestValueExtractor) extractJSONPath(jsonStr string, jsonPath strin } return fmt.Sprintf("%v", current), nil // Convert value to string (if possible) } + +func (m *Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { + logID := uuid.New().String() + ctx := context.WithValue(r.Context(), "logID", logID) + r = r.WithContext(ctx) + + m.logRequest(zapcore.DebugLevel, "Entering ServeHTTP", zap.String("path", r.URL.Path)) + + state := &WAFState{ + TotalScore: 0, + Blocked: false, + StatusCode: http.StatusOK, + ResponseWritten: false, + } + + m.logger.Info("WAF evaluation started", + zap.String("log_id", logID), + zap.String("method", r.Method), + zap.String("path", r.URL.Path), + zap.String("source_ip", r.RemoteAddr), + zap.String("user_agent", r.UserAgent()), + zap.String("query_params", r.URL.RawQuery), + ) + + block := func(statusCode int, reason string, fields ...zap.Field) error { + if !state.ResponseWritten { + m.blockRequest(w, r, state, statusCode, append(fields, zap.String("reason", reason))...) + return nil + } + m.logger.Debug("Blocking action skipped, response already written", zap.String("log_id", logID), zap.String("reason", reason)) + return nil + } + + if m.CountryWhitelist.Enabled { + whitelisted, err := m.isCountryInList(r.RemoteAddr, m.CountryWhitelist.CountryList, m.CountryWhitelist.geoIP) + if err != nil { + m.logRequest(zapcore.ErrorLevel, "Failed to check whitelist", zap.String("log_id", logID), zap.Error(err)) + } else if whitelisted { + m.logRequest(zapcore.InfoLevel, "Request allowed - country whitelisted", zap.String("log_id", logID)) + return next.ServeHTTP(w, r) + } + } + + if m.CountryBlock.Enabled { + blacklisted, err := m.isCountryInList(r.RemoteAddr, m.CountryBlock.CountryList, m.CountryBlock.geoIP) + if err != nil { + m.logRequest(zapcore.ErrorLevel, "Failed to check blacklist", zap.String("log_id", logID), zap.Error(err)) + return block(http.StatusInternalServerError, "blacklist_check_error") + } else if blacklisted { + m.logRequest(zapcore.WarnLevel, "Request blocked - country blacklisted", zap.String("log_id", logID)) + return block(http.StatusForbidden, "country_blacklist") + } + } + + m.handlePhase(w, r, 1, state) + if state.Blocked { + w.WriteHeader(state.StatusCode) + return nil + } + + m.handlePhase(w, r, 2, state) + if state.Blocked { + w.WriteHeader(state.StatusCode) + return nil + } + + recorder := &responseRecorder{ResponseWriter: w, body: new(bytes.Buffer)} + err := next.ServeHTTP(recorder, r) + + m.handlePhase(recorder, r, 3, state) + if state.Blocked { + recorder.WriteHeader(state.StatusCode) + return nil + } + + if recorder.body != nil { + body := recorder.body.String() + m.logger.Debug("Response body captured", zap.String("body", body)) + + for _, rule := range m.Rules[4] { + if rule.regex.MatchString(body) { + m.processRuleMatch(recorder, r, &rule, body, state) + if state.Blocked { + recorder.WriteHeader(state.StatusCode) + return nil + } + } + } + + if !state.ResponseWritten { + _, writeErr := w.Write(recorder.body.Bytes()) + if writeErr != nil { + m.logger.Error("Failed to write response body", zap.Error(writeErr)) + } + } + } + + if m.MetricsEndpoint != "" && r.URL.Path == m.MetricsEndpoint { + return m.handleMetricsRequest(w, r) + } + + m.logger.Info("WAF evaluation complete", + zap.String("log_id", logID), + zap.Int("total_score", state.TotalScore), + zap.Bool("blocked", state.Blocked), + ) + + return err +} diff --git a/rules.go b/rules.go index 94461e5..23a5854 100644 --- a/rules.go +++ b/rules.go @@ -2,8 +2,12 @@ package caddywaf import ( + "encoding/json" "fmt" + "net" "net/http" + "os" + "regexp" "strings" "time" @@ -155,3 +159,147 @@ func validateRule(rule *Rule) error { } return nil } + +func (m *Middleware) loadRules(paths []string, ipBlacklistPath string, dnsBlacklistPath string) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.logger.Debug("Loading rules and blacklists from files", zap.Strings("rule_files", paths), zap.String("ip_blacklist", ipBlacklistPath), zap.String("dns_blacklist", dnsBlacklistPath)) + + m.Rules = make(map[int][]Rule) + totalRules := 0 + var invalidFiles []string + var allInvalidRules []string + ruleIDs := make(map[string]bool) + + for _, path := range paths { + content, err := os.ReadFile(path) + if err != nil { + m.logger.Error("Failed to read rule file", zap.String("file", path), zap.Error(err)) + invalidFiles = append(invalidFiles, path) + continue + } + + var rules []Rule + if err := json.Unmarshal(content, &rules); err != nil { + m.logger.Error("Failed to unmarshal rules from file", zap.String("file", path), zap.Error(err)) + invalidFiles = append(invalidFiles, path) + continue + } + + var invalidRulesInFile []string + for i, rule := range rules { + if err := validateRule(&rule); err != nil { + invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Rule at index %d: %v", i, err)) + continue + } + + if _, exists := ruleIDs[rule.ID]; exists { + invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Duplicate rule ID '%s' at index %d", rule.ID, i)) + continue + } + ruleIDs[rule.ID] = true + + regex, err := regexp.Compile(rule.Pattern) + if err != nil { + m.logger.Error("Failed to compile regex for rule", zap.String("rule_id", rule.ID), zap.String("pattern", rule.Pattern), zap.Error(err)) + invalidRulesInFile = append(invalidRulesInFile, fmt.Sprintf("Rule '%s': invalid regex pattern: %v", rule.ID, err)) + continue + } + rule.regex = regex + + if _, ok := m.Rules[rule.Phase]; !ok { + m.Rules[rule.Phase] = []Rule{} + } + + m.Rules[rule.Phase] = append(m.Rules[rule.Phase], rule) + totalRules++ + } + if len(invalidRulesInFile) > 0 { + m.logger.Warn("Some rules failed validation", zap.String("file", path), zap.Strings("invalid_rules", invalidRulesInFile)) + allInvalidRules = append(allInvalidRules, invalidRulesInFile...) + } + + m.logger.Info("Rules loaded", zap.String("file", path), zap.Int("total_rules", len(rules)), zap.Int("invalid_rules", len(invalidRulesInFile))) + } + + m.ipBlacklist = make(map[string]struct{}) // Changed to map[string]struct{} + if ipBlacklistPath != "" { + content, err := os.ReadFile(ipBlacklistPath) + if err != nil { + m.logger.Warn("Failed to read IP blacklist file", zap.String("file", ipBlacklistPath), zap.Error(err)) + } else { + lines := strings.Split(string(content), "\n") + validEntries := 0 + for i, line := range lines { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + if _, _, err := net.ParseCIDR(line); err == nil { + m.ipBlacklist[line] = struct{}{} // Changed to struct{}{} + validEntries++ + m.logger.Debug("Added CIDR range to blacklist", zap.String("cidr", line)) + continue + } + + if ip := net.ParseIP(line); ip != nil { + m.ipBlacklist[line] = struct{}{} // Changed to struct{}{} + validEntries++ + m.logger.Debug("Added IP to blacklist", zap.String("ip", line)) + continue + } + + m.logger.Warn("Invalid IP or CIDR range in blacklist file, skipping", + zap.String("file", ipBlacklistPath), + zap.Int("line", i+1), + zap.String("entry", line), + ) + } + m.logger.Info("IP blacklist loaded successfully", + zap.String("file", ipBlacklistPath), + zap.Int("valid_entries", validEntries), + zap.Int("total_lines", len(lines)), + ) + } + } + + m.dnsBlacklist = make(map[string]struct{}) // Changed to map[string]struct{} + if dnsBlacklistPath != "" { + content, err := os.ReadFile(dnsBlacklistPath) + if err != nil { + m.logger.Warn("Failed to read DNS blacklist file", zap.String("file", dnsBlacklistPath), zap.Error(err)) + } else { + lines := strings.Split(string(content), "\n") + validEntriesCount := 0 + for _, line := range lines { + line = strings.ToLower(strings.TrimSpace(line)) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + m.dnsBlacklist[line] = struct{}{} // Changed to struct{}{} + validEntriesCount++ + } + m.logger.Info("DNS blacklist loaded successfully", + zap.String("file", dnsBlacklistPath), + zap.Int("valid_entries", validEntriesCount), + zap.Int("total_lines", len(lines)), + ) + } + } + + if len(invalidFiles) > 0 { + m.logger.Warn("Some rule files could not be loaded", zap.Strings("invalid_files", invalidFiles)) + } + if len(allInvalidRules) > 0 { + m.logger.Warn("Some rules across files failed validation", zap.Strings("invalid_rules", allInvalidRules)) + } + + if totalRules == 0 && len(invalidFiles) > 0 { + return fmt.Errorf("no valid rules were loaded from any file") + } + m.logger.Debug("Rules and Blacklists loaded successfully", zap.Int("total_rules", totalRules)) + + return nil +}