diff --git a/README.md b/README.md index 4e90ab0..056d73d 100644 --- a/README.md +++ b/README.md @@ -62,6 +62,7 @@ Secure comes with a variety of configuration options (Note: these are not the de // ... s := secure.New(secure.Options{ AllowedHosts: []string{"ssl.example.com"}, // AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names. + AllowedHostsFunc: func() []string { return []string{"example\\.com", ".*\\.example\\.com" } // AllowedHostsFunc is a custom function that returns a list of fully qualified domain names that are allowed. This can be used in combination with the above AllowedHosts. AllowedHostsAreRegex: false, // AllowedHostsAreRegex determines, if the provided AllowedHosts slice contains valid regular expressions. Default is false. HostsProxyHeaders: []string{"X-Forwarded-Hosts"}, // HostsProxyHeaders is a set of header keys that may hold a proxied hostname value for the request. SSLRedirect: true, // If SSLRedirect is set to true, then only allow HTTPS requests. Default is false. @@ -101,6 +102,7 @@ s := secure.New() l := secure.New(secure.Options{ AllowedHosts: []string, + AllowedHostsFunc: nil, AllowedHostsAreRegex: false, HostsProxyHeaders: []string, SSLRedirect: false, diff --git a/secure.go b/secure.go index f8d3c8e..754e1cc 100644 --- a/secure.go +++ b/secure.go @@ -36,6 +36,9 @@ const ( // SSLHostFunc a type whose pointer is the type of field `SSLHostFunc` of `Options` struct type SSLHostFunc func(host string) (newHost string) +// AllowedHostsFunc a custom function type that returns a list of strings used in place of AllowedHosts list +type AllowedHostsFunc func() []string + func defaultBadHostHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "Bad Host", http.StatusInternalServerError) } @@ -90,6 +93,8 @@ type Options struct { CrossOriginOpenerPolicy string // SSLHost is the host name that is used to redirect http requests to https. Default is "", which indicates to use the same host. SSLHost string + // AllowedHostsFunc is a custom function that returns a list of fully qualified domain names that are allowed. If set, values will be appended to AllowedHosts + AllowedHostsFunc AllowedHostsFunc // AllowedHosts is a list of fully qualified domain names that are allowed. Default is empty list, which allows any and all host names. AllowedHosts []string // AllowedHostsAreRegex determines, if the provided slice contains valid regular expressions. If this flag is set to true, every request's @@ -290,7 +295,13 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He } // Allowed hosts check. - if len(s.opt.AllowedHosts) > 0 && !s.opt.IsDevelopment { + combinedAllowedHosts := s.opt.AllowedHosts + + if s.opt.AllowedHostsFunc != nil { + combinedAllowedHosts = append(combinedAllowedHosts, s.opt.AllowedHostsFunc()...) + } + + if len(combinedAllowedHosts) > 0 && !s.opt.IsDevelopment { isGoodHost := false if s.opt.AllowedHostsAreRegex { for _, allowedHost := range s.cRegexAllowedHosts { @@ -300,14 +311,13 @@ func (s *Secure) processRequest(w http.ResponseWriter, r *http.Request) (http.He } } } else { - for _, allowedHost := range s.opt.AllowedHosts { + for _, allowedHost := range combinedAllowedHosts { if strings.EqualFold(allowedHost, host) { isGoodHost = true break } } } - if !isGoodHost { s.badHostHandler.ServeHTTP(w, r) return nil, nil, fmt.Errorf("bad host name: %s", host) diff --git a/secure_test.go b/secure_test.go index 0d4b129..d54db03 100644 --- a/secure_test.go +++ b/secure_test.go @@ -1448,6 +1448,37 @@ func TestMultipleCustomSecureContextKeys(t *testing.T) { expect(t, s2Headers.Get(featurePolicyHeader), s2.opt.FeaturePolicy) } +func TestAllowHostsFunc(t *testing.T) { + s := New(Options{ + AllowedHostsFunc: func() []string { return []string{"www.example.com"} }, + }) + + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/foo", nil) + req.Host = "www.example.com" + + s.Handler(myHandler).ServeHTTP(res, req) + + expect(t, res.Code, http.StatusOK) + expect(t, res.Body.String(), `bar`) +} + +func TestAllowHostsFuncWithAllowedHostsList(t *testing.T) { + s := New(Options{ + AllowedHosts: []string{"www.allow.com"}, + AllowedHostsFunc: func() []string { return []string{"www.allowfunc.com"} }, + }) + + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/foo", nil) + req.Host = "www.allow.com" + + s.Handler(myHandler).ServeHTTP(res, req) + + expect(t, res.Code, http.StatusOK) + expect(t, res.Body.String(), `bar`) +} + /* Test Helpers */ func expect(t *testing.T, a interface{}, b interface{}) { if a != b {