diff --git a/internal/integration/integration_test.go b/internal/integration/integration_test.go index 9aa2166..1c67578 100644 --- a/internal/integration/integration_test.go +++ b/internal/integration/integration_test.go @@ -164,13 +164,21 @@ func TestSubscriptionConnectionParams(t *testing.T) { opts []graphql.WebSocketOption }{ { - name: "authorized_user_gets_counter", + name: "connection_param_authorized_user_gets_counter", opts: []graphql.WebSocketOption{ graphql.WithConnectionParams(map[string]interface{}{ authKey: "authorized-user-token", }), }, }, + { + name: "header_authorized_user_gets_counter", + opts: []graphql.WebSocketOption{ + graphql.WithWebsocketHeader(http.Header{ + authKey: []string{"authorized-user-token"}, + }), + }, + }, { name: "unauthorized_user_gets_error", expectedError: "input: countAuthorized unauthorized\n", diff --git a/internal/integration/server/server.go b/internal/integration/server/server.go index e826882..5e56fe1 100644 --- a/internal/integration/server/server.go +++ b/internal/integration/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "net/http" "net/http/httptest" "strconv" "time" @@ -198,6 +199,20 @@ func getAuthToken(ctx context.Context) string { return "" } +func httpAuthMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + token := r.Header.Get(AuthKey) + if token != "" { + ctx = withAuthToken(ctx, token) + } + + r = r.WithContext(ctx) + handler.ServeHTTP(w, r) + }) +} + func RunServer() *httptest.Server { gqlgenServer := handler.New(NewExecutableSchema(Config{Resolvers: &resolver{}})) gqlgenServer.AddTransport(transport.POST{}) @@ -217,7 +232,9 @@ func RunServer() *httptest.Server { return next(ctx) }) - return httptest.NewServer(gqlgenServer) + server := httpAuthMiddleware(gqlgenServer) + + return httptest.NewServer(server) } type (