Skip to content

Commit

Permalink
Add ability to pass a custom allowed hosts function (#82)
Browse files Browse the repository at this point in the history
* add custom allowd hosts parameter to Options

* tests for AllowedHostsFunc

* add seperate checks for AllowedHostsFunc to support AllowdHosts list in tandem + tests

* clean up, just append the values from AllowedHostsFunc to the AllowedHosts list

* README update

* comment update

* readme typo fix

* readme typo fix

* Update README.md

Co-authored-by: Cory Jacobsen <[email protected]>

* combine the static a dynamic lists

Co-authored-by: Franklin De Los Santos <[email protected]>
Co-authored-by: Cory Jacobsen <[email protected]>
  • Loading branch information
3 people authored Jul 2, 2022
1 parent cb6ee76 commit 764d6a2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Secure comes with a variety of configuration options (Note: these are not the de
// ...
s := secure.New(secure.Options{
AllowedHosts: []string{"ssl.example.com"}, // AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names.
AllowedHostsFunc: func() []string { return []string{"example\\.com", ".*\\.example\\.com" } // AllowedHostsFunc is a custom function that returns a list of fully qualified domain names that are allowed. This can be used in combination with the above AllowedHosts.
AllowedHostsAreRegex: false, // AllowedHostsAreRegex determines, if the provided AllowedHosts slice contains valid regular expressions. Default is false.
HostsProxyHeaders: []string{"X-Forwarded-Hosts"}, // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request.
SSLRedirect: true, // If SSLRedirect is set to true, then only allow HTTPS requests. Default is false.
Expand Down Expand Up @@ -101,6 +102,7 @@ s := secure.New()

l := secure.New(secure.Options{
AllowedHosts: []string,
AllowedHostsFunc: nil,
AllowedHostsAreRegex: false,
HostsProxyHeaders: []string,
SSLRedirect: false,
Expand Down
16 changes: 13 additions & 3 deletions secure.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ const (
// SSLHostFunc a type whose pointer is the type of field `SSLHostFunc` of `Options` struct
type SSLHostFunc func(host string) (newHost string)

// AllowedHostsFunc a custom function type that returns a list of strings used in place of AllowedHosts list
type AllowedHostsFunc func() []string

func defaultBadHostHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Bad Host", http.StatusInternalServerError)
}
Expand Down Expand Up @@ -90,6 +93,8 @@ type Options struct {
CrossOriginOpenerPolicy string
// SSLHost is the host name that is used to redirect http requests to https. Default is "", which indicates to use the same host.
SSLHost string
// AllowedHostsFunc is a custom function that returns a list of fully qualified domain names that are allowed. If set, values will be appended to AllowedHosts
AllowedHostsFunc AllowedHostsFunc
// AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names.
AllowedHosts []string
// AllowedHostsAreRegex determines, if the provided slice contains valid regular expressions. If this flag is set to true, every request's
Expand Down Expand Up @@ -290,7 +295,13 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
}

// Allowed hosts check.
if len(s.opt.AllowedHosts) > 0 && !s.opt.IsDevelopment {
combinedAllowedHosts := s.opt.AllowedHosts

if s.opt.AllowedHostsFunc != nil {
combinedAllowedHosts = append(combinedAllowedHosts, s.opt.AllowedHostsFunc()...)
}

if len(combinedAllowedHosts) > 0 && !s.opt.IsDevelopment {
isGoodHost := false
if s.opt.AllowedHostsAreRegex {
for _, allowedHost := range s.cRegexAllowedHosts {
Expand All @@ -300,14 +311,13 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He
}
}
} else {
for _, allowedHost := range s.opt.AllowedHosts {
for _, allowedHost := range combinedAllowedHosts {
if strings.EqualFold(allowedHost, host) {
isGoodHost = true
break
}
}
}

if !isGoodHost {
s.badHostHandler.ServeHTTP(w, r)
return nil, nil, fmt.Errorf("bad host name: %s", host)
Expand Down
31 changes: 31 additions & 0 deletions secure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,37 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) {
expect(t, s2Headers.Get(featurePolicyHeader), s2.opt.FeaturePolicy)
}

func TestAllowHostsFunc(t *testing.T) {
s := New(Options{
AllowedHostsFunc: func() []string { return []string{"www.example.com"} },
})

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/foo", nil)
req.Host = "www.example.com"

s.Handler(myHandler).ServeHTTP(res, req)

expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), `bar`)
}

func TestAllowHostsFuncWithAllowedHostsList(t *testing.T) {
s := New(Options{
AllowedHosts: []string{"www.allow.com"},
AllowedHostsFunc: func() []string { return []string{"www.allowfunc.com"} },
})

res := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/foo", nil)
req.Host = "www.allow.com"

s.Handler(myHandler).ServeHTTP(res, req)

expect(t, res.Code, http.StatusOK)
expect(t, res.Body.String(), `bar`)
}

/* Test Helpers */
func expect(t *testing.T, a interface{}, b interface{}) {
if a != b {
Expand Down

0 comments on commit 764d6a2

Please sign in to comment.