Skip to content

Commit

Permalink
Minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
fabriziosalmi committed Jan 15, 2025
1 parent 8bcb6b7 commit 6ae8c31
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 348 deletions.
89 changes: 38 additions & 51 deletions blacklist.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,29 @@ import (
"go.uber.org/zap"
)

// BlacklistLoader struct
// BlacklistLoader handles loading IP and DNS blacklists from files.
type BlacklistLoader struct {
logger *zap.Logger
}

// NewBlacklistLoader creates a new BlacklistLoader with a given logger
// NewBlacklistLoader creates a new BlacklistLoader with the provided logger.
func NewBlacklistLoader(logger *zap.Logger) *BlacklistLoader {
return &BlacklistLoader{logger: logger}
}

// LoadIPBlacklistFromFile loads IP addresses from a file
func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]bool) error {
// LoadIPBlacklistFromFile loads IP addresses from a file into the provided map.
func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[string]struct{}) error {
if bl.logger == nil {
bl.logger = zap.NewNop()
}
// Initialize the IP blacklist
// Log the attempt to load the IP blacklist file
bl.logger.Debug("Loading IP blacklist from file",
zap.String("file", path),
)
bl.logger.Debug("Loading IP blacklist from file", zap.String("file", path))

// Attempt to read the file
content, err := os.ReadFile(path)
if err != nil {
bl.logger.Warn("Failed to read IP blacklist file",
zap.String("file", path),
zap.Error(err),
)
return nil // Continue with an empty blacklist
bl.logger.Warn("Failed to read IP blacklist file", zap.String("file", path), zap.Error(err))
return fmt.Errorf("failed to read IP blacklist file: %w", err)
}

// Split the file content into lines
lines := strings.Split(string(content), "\n")
validEntries := 0

Expand All @@ -51,28 +42,22 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
continue // Skip empty lines and comments
}

// Check if the line is a valid IP or CIDR range
if _, _, err := net.ParseCIDR(line); err == nil {
// It's a valid CIDR range
ipBlacklist[line] = true
// Valid CIDR range
ipBlacklist[line] = struct{}{}
validEntries++
bl.logger.Debug("Added CIDR range to blacklist",
zap.String("cidr", line),
)
bl.logger.Debug("Added CIDR range to blacklist", zap.String("cidr", line))
continue
}

if ip := net.ParseIP(line); ip != nil {
// It's a valid IP address
ipBlacklist[line] = true
// Valid IP address
ipBlacklist[line] = struct{}{}
validEntries++
bl.logger.Debug("Added IP to blacklist",
zap.String("ip", line),
)
bl.logger.Debug("Added IP to blacklist", zap.String("ip", line))
continue
}

// Log invalid entries for debugging
bl.logger.Warn("Invalid IP or CIDR range in blacklist file, skipping",
zap.String("file", path),
zap.Int("line", i+1),
Expand All @@ -88,46 +73,36 @@ func (bl *BlacklistLoader) LoadIPBlacklistFromFile(path string, ipBlacklist map[
return nil
}

// LoadDNSBlacklistFromFile loads DNS entries from a file
func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist map[string]bool) error {
// LoadDNSBlacklistFromFile loads DNS entries from a file into the provided map.
func (bl *BlacklistLoader) LoadDNSBlacklistFromFile(path string, dnsBlacklist map[string]struct{}) error {
if bl.logger == nil {
bl.logger = zap.NewNop()
}
// Log the attempt to load the DNS blacklist file
bl.logger.Debug("Loading DNS blacklist from file",
zap.String("file", path),
)
bl.logger.Debug("Loading DNS blacklist from file", zap.String("file", path))

// Attempt to read the file
content, err := os.ReadFile(path)
if err != nil {
bl.logger.Warn("Failed to read DNS blacklist file",
zap.String("file", path),
zap.Error(err),
)
return nil // Continue with an empty blacklist
bl.logger.Warn("Failed to read DNS blacklist file", zap.String("file", path), zap.Error(err))
return fmt.Errorf("failed to read DNS blacklist file: %w", err)
}

// Convert all entries to lowercase and trim whitespace and add to the map
lines := strings.Split(string(content), "\n")
validEntriesCount := 0
validEntries := 0

for _, line := range lines {
line = strings.ToLower(strings.TrimSpace(line))
if line == "" || strings.HasPrefix(line, "#") {
continue // Skip empty lines and comments
}
dnsBlacklist[line] = true
validEntriesCount++
dnsBlacklist[line] = struct{}{}
validEntries++
}

// Log the successful loading of the DNS blacklist
bl.logger.Info("DNS blacklist loaded successfully",
zap.String("file", path),
zap.Int("valid_entries", validEntriesCount),
zap.Int("valid_entries", validEntries),
zap.Int("total_lines", len(lines)),
)

return nil
}

Expand All @@ -138,8 +113,11 @@ func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {
return false
}

m.mu.RLock()
defer m.mu.RUnlock()

// Check if the IP is directly blacklisted
if m.ipBlacklist[ipStr] {
if _, exists := m.ipBlacklist[ipStr]; exists {
return true
}

Expand All @@ -164,13 +142,15 @@ func (m *Middleware) isIPBlacklisted(remoteAddr string) bool {
return false
}

// isCountryInList checks if the IP's country is in the provided list using the GeoIP database.
func (m *Middleware) isCountryInList(remoteAddr string, countryList []string, geoIP *maxminddb.Reader) (bool, error) {
if m.geoIPHandler == nil {
return false, fmt.Errorf("geoip handler not initialized")
}
return m.geoIPHandler.IsCountryInList(remoteAddr, countryList, geoIP)
}

// isDNSBlacklisted checks if the given host is in the DNS blacklist.
func (m *Middleware) isDNSBlacklisted(host string) bool {
normalizedHost := strings.ToLower(strings.TrimSpace(host))
if normalizedHost == "" {
Expand All @@ -189,8 +169,15 @@ func (m *Middleware) isDNSBlacklisted(host string) bool {
return true
}

m.logger.Debug("Host is not blacklisted",
zap.String("host", host),
)
m.logger.Debug("Host is not blacklisted", zap.String("host", host))
return false
}

// extractIP extracts the IP address from a remote address string.
func extractIP(remoteAddr string) string {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr // Assume the input is already an IP address
}
return host
}
Loading

0 comments on commit 6ae8c31

Please sign in to comment.