From 27b792cc705b5804dabb31c939a73121f4fc4ede Mon Sep 17 00:00:00 2001 From: Daniel Rocha <68558152+danroc@users.noreply.github.com> Date: Sat, 23 Nov 2024 09:19:07 +0100 Subject: [PATCH] feat: allow wildcard domains in config (#42) * feat: allow wildcard domains in config * chore: rename file `loader.go` to `reader.go` --- cmd/geoblock/main.go | 14 +- pkg/config/{loader.go => reader.go} | 26 +++- pkg/config/reader_test.go | 217 ++++++++++++++++++++++++++++ pkg/config/schema.go | 2 +- 4 files changed, 251 insertions(+), 8 deletions(-) rename pkg/config/{loader.go => reader.go} (51%) create mode 100644 pkg/config/reader_test.go diff --git a/cmd/geoblock/main.go b/cmd/geoblock/main.go index 319722b..efd211e 100644 --- a/cmd/geoblock/main.go +++ b/cmd/geoblock/main.go @@ -2,6 +2,7 @@ package main import ( + "bytes" "os" "time" @@ -51,6 +52,15 @@ func autoUpdate(resolver *iprange.Resolver) { } } +// loadConfig reads the configuration file from the given path and returns it. +func loadConfig(path string) (*config.Configuration, error) { + file, err := os.ReadFile(path) // #nosec G304 + if err != nil { + return nil, err + } + return config.ReadConfig(bytes.NewReader(file)) +} + // hasChanged returns true if the two file infos are different. It only checks // the size and the modification time. func hasChanged(a, b os.FileInfo) bool { @@ -78,7 +88,7 @@ func autoReload(engine *rules.Engine, path string) { } prevStat = stat - cfg, err := config.LoadConfig(path) + cfg, err := loadConfig(path) if err != nil { log.Errorf("Cannot read configuration file: %v", err) continue @@ -110,7 +120,7 @@ func main() { configureLogger(options.logLevel) log.Info("Loading configuration file") - cfg, err := config.LoadConfig(options.configPath) + cfg, err := loadConfig(options.configPath) if err != nil { log.Fatalf("Cannot read configuration file: %v", err) } diff --git a/pkg/config/loader.go b/pkg/config/reader.go similarity index 51% rename from pkg/config/loader.go rename to pkg/config/reader.go index 6eba40b..30efc05 100644 --- a/pkg/config/loader.go +++ b/pkg/config/reader.go @@ -3,12 +3,27 @@ package config import ( - "os" + "io" + "regexp" "github.com/go-playground/validator/v10" "gopkg.in/yaml.v3" ) +// DomainNameRegex matches a valid domain name as per RFC 1035. It also allows +// labels to be a single `*` wildcard. +var domainNameRegex = regexp.MustCompile( + `^(\*|[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)(\.(\*|[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?))*$`, +) + +func isDomainNameField(field validator.FieldLevel) bool { + domain, ok := field.Field().Interface().(string) + if !ok { + return false + } + return domainNameRegex.MatchString(domain) +} + // isCIDRField checks if the value of the given field is a valid CIDR. func isCIDRField(field validator.FieldLevel) bool { cidr, ok := field.Field().Interface().(CIDR) @@ -26,7 +41,8 @@ func read(data []byte) (*Configuration, error) { } validate := validator.New() - validate.RegisterValidation("cidr", isCIDRField) // #nosec G104 + validate.RegisterValidation("cidr", isCIDRField) // #nosec G104 + validate.RegisterValidation("domain", isDomainNameField) // #nosec G104 if err := validate.Struct(config); err != nil { return nil, err @@ -35,9 +51,9 @@ func read(data []byte) (*Configuration, error) { return &config, nil } -// LoadConfig reads the configuration from the given file. -func LoadConfig(filename string) (*Configuration, error) { - data, err := os.ReadFile(filename) // #nosec G304 +// ReadConfig reads the configuration from the given reader and returns it. +func ReadConfig(reader io.Reader) (*Configuration, error) { + data, err := io.ReadAll(reader) if err != nil { return nil, err } diff --git a/pkg/config/reader_test.go b/pkg/config/reader_test.go new file mode 100644 index 0000000..ed24637 --- /dev/null +++ b/pkg/config/reader_test.go @@ -0,0 +1,217 @@ +package config_test + +import ( + "errors" + "net" + "reflect" + "strings" + "testing" + + "github.com/danroc/geoblock/pkg/config" +) + +const validConfig = ` +access_control: + default_policy: allow + rules: + - networks: + - "10.0.0.0/8" + - "127.0.0.0/8" + domains: + - "example.com" + - "*.example.com" + methods: + - GET + - POST + countries: + - US + - FR + autonomous_systems: + - 1234 + - 5678 + policy: allow + + - policy: deny +` + +const invalidLeadingDot = ` +access_control: + default_policy: allow + rules: + - domains: + - ".example.com" + policy: allow +` + +const invalidWildcardLocation = ` +access_control: + default_policy: allow + rules: + - domains: + - "*example.com" + policy: allow +` + +const invalidDomainChar = ` +access_control: + default_policy: allow + rules: + - domains: + - "example?.com" + policy: allow +` + +const invalidLeadingDash = ` +access_control: + default_policy: allow + rules: + - domains: + - "-example.com" + policy: allow +` + +const invalidTrailingDash = ` +access_control: + default_policy: allow + rules: + - domains: + - "example-.com" + policy: allow +` + +const invalidDomainString = ` +access_control: + default_policy: allow + rules: + - domains: + - false + policy: allow +` + +const invalidNetworkString = ` +access_control: + default_policy: allow + rules: + - networks: + - "invalid" + policy: allow +` + +const invalidNetworkNumber = ` +access_control: + default_policy: allow + rules: + - networks: + - 10 + policy: allow +` + +const invalidNetworkRange = ` +access_control: + default_policy: allow + rules: + - networks: + - 300.300.300.300/50 + policy: allow +` + +func TestReadConfigValid(t *testing.T) { + tests := []struct { + name string + data string + expected *config.Configuration + }{ + { + "valid configuration", + validConfig, + &config.Configuration{ + AccessControl: config.AccessControl{ + DefaultPolicy: "allow", + Rules: []config.AccessControlRule{ + { + Policy: "allow", + Networks: []config.CIDR{ + { + IPNet: &net.IPNet{ + IP: net.IP{10, 0, 0, 0}, + Mask: net.CIDRMask(8, 32), + }, + }, + { + IPNet: &net.IPNet{ + IP: net.IP{127, 0, 0, 0}, + Mask: net.CIDRMask(8, 32), + }, + }, + }, + Domains: []string{ + "example.com", + "*.example.com", + }, + Methods: []string{"GET", "POST"}, + Countries: []string{"US", "FR"}, + AutonomousSystems: []uint32{1234, 5678}, + }, + { + Policy: "deny", + Networks: nil, + Domains: nil, + Methods: nil, + Countries: nil, + AutonomousSystems: nil, + }, + }, + }, + }, + }, + } + + for _, test := range tests { + reader := strings.NewReader(test.data) + cfg, err := config.ReadConfig(reader) + if err != nil { + t.Errorf("%s: unexpected error: %v", test.name, err) + } + if !reflect.DeepEqual(*cfg, *test.expected) { + t.Errorf("%s: expected %v, got %v", test.name, test.expected, cfg) + } + } +} + +func TestReadConfigErr(t *testing.T) { + tests := []struct { + name string + data string + }{ + {"invalid leading dot", invalidLeadingDot}, + {"invalid wildcard location", invalidWildcardLocation}, + {"invalid domain character", invalidDomainChar}, + {"invalid leading dash", invalidLeadingDash}, + {"invalid trailing dash", invalidTrailingDash}, + {"invalid network string", invalidNetworkString}, + {"invalid network number", invalidNetworkNumber}, + {"invalid network range", invalidNetworkRange}, + {"invalid domain string", invalidDomainString}, + } + + for _, test := range tests { + reader := strings.NewReader(test.data) + _, err := config.ReadConfig(reader) + if err == nil { + t.Errorf("%s: expected an error but got nil", test.name) + } + } +} + +type errReader struct{} + +func (r *errReader) Read(p []byte) (n int, err error) { + return 0, errors.New("read error") +} + +func TestReadConfigErrReader(t *testing.T) { + _, err := config.ReadConfig(&errReader{}) + if err == nil { + t.Error("expected an error but got nil") + } +} diff --git a/pkg/config/schema.go b/pkg/config/schema.go index 0bf1a9e..22f869a 100644 --- a/pkg/config/schema.go +++ b/pkg/config/schema.go @@ -33,7 +33,7 @@ func (n *CIDR) UnmarshalYAML(unmarshal func(interface{}) error) error { type AccessControlRule struct { Policy string `yaml:"policy" validate:"required,oneof=allow deny"` Networks []CIDR `yaml:"networks,omitempty" validate:"dive,cidr"` - Domains []string `yaml:"domains,omitempty" validate:"dive,fqdn"` + Domains []string `yaml:"domains,omitempty" validate:"dive,domain"` Methods []string `yaml:"methods,omitempty" validate:"dive,oneof=GET HEAD POST PUT DELETE PATCH"` Countries []string `yaml:"countries,omitempty" validate:"dive,iso3166_1_alpha2"` AutonomousSystems []uint32 `yaml:"autonomous_systems,omitempty" validate:"dive,numeric"`