Skip to content

Commit

Permalink
feat: allow wildcard domains in config (#42)
Browse files Browse the repository at this point in the history
* feat: allow wildcard domains in config

* chore: rename file `loader.go` to `reader.go`
  • Loading branch information
danroc authored Nov 23, 2024
1 parent f298ee9 commit 27b792c
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 8 deletions.
14 changes: 12 additions & 2 deletions cmd/geoblock/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package main

import (
"bytes"
"os"
"time"

Expand Down Expand Up @@ -51,6 +52,15 @@ func autoUpdate(resolver *iprange.Resolver) {
}
}

// loadConfig reads the configuration file from the given path and returns it.
func loadConfig(path string) (*config.Configuration, error) {
file, err := os.ReadFile(path) // #nosec G304
if err != nil {
return nil, err
}
return config.ReadConfig(bytes.NewReader(file))
}

// hasChanged returns true if the two file infos are different. It only checks
// the size and the modification time.
func hasChanged(a, b os.FileInfo) bool {
Expand Down Expand Up @@ -78,7 +88,7 @@ func autoReload(engine *rules.Engine, path string) {
}
prevStat = stat

cfg, err := config.LoadConfig(path)
cfg, err := loadConfig(path)
if err != nil {
log.Errorf("Cannot read configuration file: %v", err)
continue
Expand Down Expand Up @@ -110,7 +120,7 @@ func main() {
configureLogger(options.logLevel)

log.Info("Loading configuration file")
cfg, err := config.LoadConfig(options.configPath)
cfg, err := loadConfig(options.configPath)
if err != nil {
log.Fatalf("Cannot read configuration file: %v", err)
}
Expand Down
26 changes: 21 additions & 5 deletions pkg/config/loader.go → pkg/config/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,27 @@
package config

import (
"os"
"io"
"regexp"

"github.com/go-playground/validator/v10"
"gopkg.in/yaml.v3"
)

// DomainNameRegex matches a valid domain name as per RFC 1035. It also allows
// labels to be a single `*` wildcard.
var domainNameRegex = regexp.MustCompile(
`^(\*|[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)(\.(\*|[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?))*$`,
)

func isDomainNameField(field validator.FieldLevel) bool {
domain, ok := field.Field().Interface().(string)
if !ok {
return false
}
return domainNameRegex.MatchString(domain)
}

// isCIDRField checks if the value of the given field is a valid CIDR.
func isCIDRField(field validator.FieldLevel) bool {
cidr, ok := field.Field().Interface().(CIDR)
Expand All @@ -26,7 +41,8 @@ func read(data []byte) (*Configuration, error) {
}

validate := validator.New()
validate.RegisterValidation("cidr", isCIDRField) // #nosec G104
validate.RegisterValidation("cidr", isCIDRField) // #nosec G104
validate.RegisterValidation("domain", isDomainNameField) // #nosec G104

if err := validate.Struct(config); err != nil {
return nil, err
Expand All @@ -35,9 +51,9 @@ func read(data []byte) (*Configuration, error) {
return &config, nil
}

// LoadConfig reads the configuration from the given file.
func LoadConfig(filename string) (*Configuration, error) {
data, err := os.ReadFile(filename) // #nosec G304
// ReadConfig reads the configuration from the given reader and returns it.
func ReadConfig(reader io.Reader) (*Configuration, error) {
data, err := io.ReadAll(reader)
if err != nil {
return nil, err
}
Expand Down
217 changes: 217 additions & 0 deletions pkg/config/reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package config_test

import (
"errors"
"net"
"reflect"
"strings"
"testing"

"github.com/danroc/geoblock/pkg/config"
)

const validConfig = `
access_control:
default_policy: allow
rules:
- networks:
- "10.0.0.0/8"
- "127.0.0.0/8"
domains:
- "example.com"
- "*.example.com"
methods:
- GET
- POST
countries:
- US
- FR
autonomous_systems:
- 1234
- 5678
policy: allow
- policy: deny
`

const invalidLeadingDot = `
access_control:
default_policy: allow
rules:
- domains:
- ".example.com"
policy: allow
`

const invalidWildcardLocation = `
access_control:
default_policy: allow
rules:
- domains:
- "*example.com"
policy: allow
`

const invalidDomainChar = `
access_control:
default_policy: allow
rules:
- domains:
- "example?.com"
policy: allow
`

const invalidLeadingDash = `
access_control:
default_policy: allow
rules:
- domains:
- "-example.com"
policy: allow
`

const invalidTrailingDash = `
access_control:
default_policy: allow
rules:
- domains:
- "example-.com"
policy: allow
`

const invalidDomainString = `
access_control:
default_policy: allow
rules:
- domains:
- false
policy: allow
`

const invalidNetworkString = `
access_control:
default_policy: allow
rules:
- networks:
- "invalid"
policy: allow
`

const invalidNetworkNumber = `
access_control:
default_policy: allow
rules:
- networks:
- 10
policy: allow
`

const invalidNetworkRange = `
access_control:
default_policy: allow
rules:
- networks:
- 300.300.300.300/50
policy: allow
`

func TestReadConfigValid(t *testing.T) {
tests := []struct {
name string
data string
expected *config.Configuration
}{
{
"valid configuration",
validConfig,
&config.Configuration{
AccessControl: config.AccessControl{
DefaultPolicy: "allow",
Rules: []config.AccessControlRule{
{
Policy: "allow",
Networks: []config.CIDR{
{
IPNet: &net.IPNet{
IP: net.IP{10, 0, 0, 0},
Mask: net.CIDRMask(8, 32),
},
},
{
IPNet: &net.IPNet{
IP: net.IP{127, 0, 0, 0},
Mask: net.CIDRMask(8, 32),
},
},
},
Domains: []string{
"example.com",
"*.example.com",
},
Methods: []string{"GET", "POST"},
Countries: []string{"US", "FR"},
AutonomousSystems: []uint32{1234, 5678},
},
{
Policy: "deny",
Networks: nil,
Domains: nil,
Methods: nil,
Countries: nil,
AutonomousSystems: nil,
},
},
},
},
},
}

for _, test := range tests {
reader := strings.NewReader(test.data)
cfg, err := config.ReadConfig(reader)
if err != nil {
t.Errorf("%s: unexpected error: %v", test.name, err)
}
if !reflect.DeepEqual(*cfg, *test.expected) {
t.Errorf("%s: expected %v, got %v", test.name, test.expected, cfg)
}
}
}

func TestReadConfigErr(t *testing.T) {
tests := []struct {
name string
data string
}{
{"invalid leading dot", invalidLeadingDot},
{"invalid wildcard location", invalidWildcardLocation},
{"invalid domain character", invalidDomainChar},
{"invalid leading dash", invalidLeadingDash},
{"invalid trailing dash", invalidTrailingDash},
{"invalid network string", invalidNetworkString},
{"invalid network number", invalidNetworkNumber},
{"invalid network range", invalidNetworkRange},
{"invalid domain string", invalidDomainString},
}

for _, test := range tests {
reader := strings.NewReader(test.data)
_, err := config.ReadConfig(reader)
if err == nil {
t.Errorf("%s: expected an error but got nil", test.name)
}
}
}

type errReader struct{}

func (r *errReader) Read(p []byte) (n int, err error) {
return 0, errors.New("read error")
}

func TestReadConfigErrReader(t *testing.T) {
_, err := config.ReadConfig(&errReader{})
if err == nil {
t.Error("expected an error but got nil")
}
}
2 changes: 1 addition & 1 deletion pkg/config/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (n *CIDR) UnmarshalYAML(unmarshal func(interface{}) error) error {
type AccessControlRule struct {
Policy string `yaml:"policy" validate:"required,oneof=allow deny"`
Networks []CIDR `yaml:"networks,omitempty" validate:"dive,cidr"`
Domains []string `yaml:"domains,omitempty" validate:"dive,fqdn"`
Domains []string `yaml:"domains,omitempty" validate:"dive,domain"`
Methods []string `yaml:"methods,omitempty" validate:"dive,oneof=GET HEAD POST PUT DELETE PATCH"`
Countries []string `yaml:"countries,omitempty" validate:"dive,iso3166_1_alpha2"`
AutonomousSystems []uint32 `yaml:"autonomous_systems,omitempty" validate:"dive,numeric"`
Expand Down

0 comments on commit 27b792c

Please sign in to comment.