diff --git a/helpers_test.go b/helpers_test.go index acf0b06..1c85497 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -82,7 +82,7 @@ func assertHeaderMissing(t *testing.T, w *httptest.ResponseRecorder, key string) t.Helper() if got := w.Header().Get(key); got != "" { - t.Fatalf("header=%s, want=%s", got, "") + t.Fatalf("unexpected header %s=%s, want=empty", key, got) } } @@ -133,6 +133,16 @@ func assertInertiaVary(t *testing.T, w *httptest.ResponseRecorder) { } } +func assertInertiaNotVary(t *testing.T, w *httptest.ResponseRecorder) { + t.Helper() + + gotVary := w.Header().Get("Vary") + + if gotVary != "" { + t.Fatal("unexpected Vary header found") + } +} + func assertHandlerServed(t *testing.T, handlers ...http.HandlerFunc) http.HandlerFunc { t.Helper() diff --git a/http.go b/http.go index 98cff1d..9e39235 100644 --- a/http.go +++ b/http.go @@ -5,22 +5,40 @@ import ( "strings" ) +const ( + headerInertia = "X-Inertia" + headerInertiaLocation = "X-Inertia-Location" + headerInertiaPartialData = "X-Inertia-Partial-Data" + headerInertiaPartialExcept = "X-Inertia-Partial-Except" + headerInertiaPartialComponent = "X-Inertia-Partial-Component" + headerInertiaVersion = "X-Inertia-Version" + headerVary = "Vary" + headerContentType = "Content-Type" +) + // IsInertiaRequest returns true if the request is an Inertia request. func IsInertiaRequest(r *http.Request) bool { - return r.Header.Get("X-Inertia") != "" + return r.Header.Get(headerInertia) != "" } func setInertiaInResponse(w http.ResponseWriter) { - w.Header().Set("X-Inertia", "true") + w.Header().Set(headerInertia, "true") +} + +func deleteInertiaInResponse(w http.ResponseWriter) { + w.Header().Del(headerInertia) } func setInertiaVaryInResponse(w http.ResponseWriter) { - w.Header().Set("Vary", "X-Inertia") + w.Header().Set(headerVary, headerInertia) +} + +func deleteVaryInResponse(w http.ResponseWriter) { + w.Header().Del(headerVary) } func setInertiaLocationInResponse(w http.ResponseWriter, url string) { - w.Header().Set("X-Inertia-Location", url) - setResponseStatus(w, http.StatusConflict) + w.Header().Set(headerInertiaLocation, url) } func setResponseStatus(w http.ResponseWriter, status int) { @@ -28,7 +46,7 @@ func setResponseStatus(w http.ResponseWriter, status int) { } func onlyFromRequest(r *http.Request) []string { - header := r.Header.Get("X-Inertia-Partial-Data") + header := r.Header.Get(headerInertiaPartialData) if header == "" { return nil } @@ -37,7 +55,7 @@ func onlyFromRequest(r *http.Request) []string { } func exceptFromRequest(r *http.Request) []string { - header := r.Header.Get("X-Inertia-Partial-Except") + header := r.Header.Get(headerInertiaPartialExcept) if header == "" { return nil } @@ -46,11 +64,11 @@ func exceptFromRequest(r *http.Request) []string { } func partialComponentFromRequest(r *http.Request) string { - return r.Header.Get("X-Inertia-Partial-Component") + return r.Header.Get(headerInertiaPartialComponent) } func inertiaVersionFromRequest(r *http.Request) string { - return r.Header.Get("X-Inertia-Version") + return r.Header.Get(headerInertiaVersion) } func redirectResponse(w http.ResponseWriter, r *http.Request, url string, status ...int) { @@ -58,15 +76,15 @@ func redirectResponse(w http.ResponseWriter, r *http.Request, url string, status } func setJSONResponse(w http.ResponseWriter) { - w.Header().Set("Content-Type", "application/json") + w.Header().Set(headerContentType, "application/json") } func setJSONRequest(r *http.Request) { - r.Header.Set("Content-Type", "application/json") + r.Header.Set(headerContentType, "application/json") } func setHTMLResponse(w http.ResponseWriter) { - w.Header().Set("Content-Type", "text/html") + w.Header().Set(headerContentType, "text/html") } func isSeeOtherRedirectMethod(method string) bool { diff --git a/middleware_test.go b/middleware_test.go index 64873ea..64d4b3c 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -80,9 +80,10 @@ func TestInertia_Middleware(t *testing.T) { asInertiaRequest(r) withInertiaVersion(r, "bar") - i.Middleware(assertHandlerServed(t, successJSONHandler)).ServeHTTP(w, r) + i.Middleware(assertHandlerServed(t, setInertiaResponseHandler, successJSONHandler)).ServeHTTP(w, r) - assertInertiaVary(t, w) + assertInertiaNotVary(t, w) + assertNotInertiaResponse(t, w) assertResponseStatusCode(t, w, http.StatusConflict) assertInertiaLocation(t, w, "/home") @@ -272,3 +273,7 @@ func setHeadersHandler(headers map[string]string) http.HandlerFunc { } } } + +func setInertiaResponseHandler(w http.ResponseWriter, _ *http.Request) { + setInertiaInResponse(w) +} diff --git a/response.go b/response.go index 8c6a2af..e99cbd1 100644 --- a/response.go +++ b/response.go @@ -63,6 +63,9 @@ func (i *Inertia) Location(w http.ResponseWriter, r *http.Request, url string, s if IsInertiaRequest(r) { setInertiaLocationInResponse(w, url) + deleteInertiaInResponse(w) + deleteVaryInResponse(w) + setResponseStatus(w, http.StatusConflict) return }