Skip to content

Commit

Permalink
Merge pull request #285 from porters-xyz/cors-fix
Browse files Browse the repository at this point in the history
Properly managing CORS headers
  • Loading branch information
wtfsayo authored May 25, 2024
2 parents f3d3eee + 686c5c7 commit b802af7
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 17 deletions.
49 changes: 32 additions & 17 deletions gateway/plugins/origin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
log "log/slog"
"net/http"
"regexp"
"strings"

"porters/db"
"porters/proxy"
Expand Down Expand Up @@ -43,14 +43,7 @@ func (a *AllowedOriginFilter) HandleRequest(req *http.Request) error {
}

rules := a.getRulesForScope(ctx, app)
allow := (len(rules) == 0)

for _, rule := range rules {
if rule.MatchString(origin) {
allow = true
break
}
}
allow := a.matchesRules(origin, rules)

if !allow {
return proxy.NewHTTPError(http.StatusUnauthorized)
Expand All @@ -59,8 +52,26 @@ func (a *AllowedOriginFilter) HandleRequest(req *http.Request) error {
return nil
}

func (a *AllowedOriginFilter) getRulesForScope(ctx context.Context, app *db.App) []regexp.Regexp {
origins := make([]regexp.Regexp, 0)
func (a *AllowedOriginFilter) HandleResponse(resp *http.Response) error {
ctx := resp.Request.Context()
app := &db.App{
Id: proxy.PluckAppId(resp.Request),
}
err := app.Lookup(ctx)
if err != nil {
return nil // don't modify header
}

rules := a.getRulesForScope(ctx, app)
if len(rules) > 0 {
allowedOrigins := strings.Join(rules, ",")
resp.Header.Set("Access-Control-Allow-Origin", allowedOrigins)
}
return nil
}

func (a *AllowedOriginFilter) getRulesForScope(ctx context.Context, app *db.App) []string {
origins := make([]string, 0)
rules, err := app.Rules(ctx)
if err != nil {
log.Error("couldn't get rules", "app", app.HashId(), "err", err)
Expand All @@ -69,13 +80,17 @@ func (a *AllowedOriginFilter) getRulesForScope(ctx context.Context, app *db.App)
if rule.RuleType != ALLOWED_ORIGIN || !rule.Active {
continue
}
matcher, err := regexp.Compile(rule.Value)
if err != nil {
log.Error("error compiling origin regex", "regex", rule.Value, "err", err)
continue
}
origins = append(origins, *matcher)
origins = append(origins, rule.Value)
}
}
return origins
}

func (a *AllowedOriginFilter) matchesRules(origin string, rules []string) bool {
for _, rule := range rules {
if strings.EqualFold(rule, origin) {
return true
}
}
return false
}
27 changes: 27 additions & 0 deletions gateway/plugins/origin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package plugins

import (
"testing"
)

func TestAllowedOriginMatches(t *testing.T) {
want := true
origin := "http://test.com"
allowed := []string{"http://test2.com", "http://test.com"}
filter := &AllowedOriginFilter{}
got := filter.matchesRules(origin, allowed)
if want != got {
t.Fatal("origin doesn't match")
}
}

func TestAllowedOriginMismatch(t *testing.T) {
want := false
origin := "http://test3.com"
allowed := []string{"http://test2.com", "http://test.com"}
filter := &AllowedOriginFilter{}
got := filter.matchesRules(origin, allowed)
if want != got {
t.Fatal("origin doesn't match")
}
}
7 changes: 7 additions & 0 deletions gateway/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ func setupProxy(remote *url.URL) *httputil.ReverseProxy {

revProxy.ModifyResponse = func(resp *http.Response) error {
ctx := resp.Request.Context()
defaultHeaders(resp)

if common.Enabled(common.INSTRUMENT_ENABLED) {
instr, ok := common.FromContext(ctx, common.INSTRUMENT)
Expand Down Expand Up @@ -191,6 +192,12 @@ func setupContext(req *http.Request) {
*req = *req.WithContext(ctx)
}

// Add or remove headers on response
// Dealing with CORS mostly
func defaultHeaders(resp *http.Response) {
resp.Header.Set("Access-Control-Allow-Origin", "*")
}

func lookupPoktId(req *http.Request) (string, bool) {
ctx := req.Context()
name := PluckProductName(req)
Expand Down

0 comments on commit b802af7

Please sign in to comment.