From 217ca8ff6e45a1ffe8c4e0189c2b6c4126794256 Mon Sep 17 00:00:00 2001 From: Roman Sarvarov Date: Thu, 18 Jul 2024 23:16:08 +0300 Subject: [PATCH 1/3] redirect refactoring #8 --- helpers_test.go | 2 +- response.go | 27 ++++++---- response_test.go | 132 ++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 126 insertions(+), 35 deletions(-) diff --git a/helpers_test.go b/helpers_test.go index 986aee6..512e9ab 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -70,7 +70,7 @@ func assertHeader(t *testing.T, w *httptest.ResponseRecorder, key, want string) t.Helper() if got := w.Header().Get(key); got != want { - t.Fatalf("header=%s, want=%s", got, want) + t.Fatalf("header %s=%s, want=%s", strings.ToLower(key), got, want) } } diff --git a/response.go b/response.go index e56a4f5..b827ee0 100644 --- a/response.go +++ b/response.go @@ -69,6 +69,22 @@ func (i *Inertia) Location(w http.ResponseWriter, r *http.Request, url string, s redirectResponse(w, r, url, status...) } +// Back creates redirect response to the previous url. +func (i *Inertia) Back(w http.ResponseWriter, r *http.Request, status ...int) { + i.Redirect(w, r, i.backURL(r), status...) +} + +func (i *Inertia) backURL(r *http.Request) string { + // At the moment, it based only on the "Referer" HTTP header. + return refererFromRequest(r) +} + +// Redirect creates plain redirect response. +func (i *Inertia) Redirect(w http.ResponseWriter, r *http.Request, url string, status ...int) { + i.flashValidationErrorsFromContext(r.Context()) + redirectResponse(w, r, url, status...) +} + func (i *Inertia) flashValidationErrorsFromContext(ctx context.Context) { if i.flash == nil { return @@ -90,16 +106,6 @@ func (i *Inertia) flashValidationErrorsFromContext(ctx context.Context) { } } -// Back creates redirect response to the previous url. -func (i *Inertia) Back(w http.ResponseWriter, r *http.Request, status ...int) { - i.Location(w, r, i.backURL(r), status...) -} - -func (i *Inertia) backURL(r *http.Request) string { - // At the moment, it based only on the "Referer" HTTP header. - return refererFromRequest(r) -} - // Render returns response with Inertia data. // // If request was made by Inertia - it will return data in JSON format. @@ -153,7 +159,6 @@ func (i *Inertia) prepareProps(r *http.Request, component string, props Props) ( { // Add validation errors from context to the result. - // Get validation errors from context. validationErrors, err := ValidationErrorsFromContext(r.Context()) if err != nil { return nil, fmt.Errorf("get validation errors from context: %w", err) diff --git a/response_test.go b/response_test.go index e9d4cba..d8f1d3b 100644 --- a/response_test.go +++ b/response_test.go @@ -497,14 +497,8 @@ func TestInertia_Location(t *testing.T) { "baz": "quz", } - wantStatus := http.StatusFound - wantLocation := "/foo" - r = r.WithContext(WithValidationErrors(r.Context(), errors)) - i.Location(w, r, wantLocation) - - assertResponseStatusCode(t, w, wantStatus) - assertLocation(t, w, wantLocation) + i.Location(w, r, "/foo") if !reflect.DeepEqual(flashProvider.errors, errors) { t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) @@ -517,9 +511,6 @@ func TestInertia_Location(t *testing.T) { w, r := requestMock(http.MethodGet, "/") asInertiaRequest(r) - wantLocation := "" - wantInertiaLocation := "/foo" - flashProvider := &flashProviderMock{} i := I(func(i *Inertia) { @@ -532,11 +523,7 @@ func TestInertia_Location(t *testing.T) { } r = r.WithContext(WithValidationErrors(r.Context(), errors)) - i.Location(w, r, wantInertiaLocation, http.StatusMovedPermanently) - - assertLocation(t, w, wantLocation) - assertResponseStatusCode(t, w, http.StatusConflict) - assertInertiaLocation(t, w, wantInertiaLocation) + i.Location(w, r, "/foo", http.StatusMovedPermanently) if !reflect.DeepEqual(flashProvider.errors, errors) { t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) @@ -545,10 +532,84 @@ func TestInertia_Location(t *testing.T) { }) } +func TestInertia_Redirect(t *testing.T) { + t.Parallel() + + t.Run("with default status", func(t *testing.T) { + t.Parallel() + + wantStatus := http.StatusFound + wantLocation := "https://example.com/foo" + + w, r := requestMock(http.MethodGet, "/") + + i := I() + + i.Redirect(w, r, wantLocation) + + assertResponseStatusCode(t, w, wantStatus) + assertLocation(t, w, wantLocation) + }) + + t.Run("with specified status", func(t *testing.T) { + t.Parallel() + + wantStatus := http.StatusMovedPermanently + wantLocation := "https://example.com/foo" + + w, r := requestMock(http.MethodGet, "/") + + I().Redirect(w, r, wantLocation, wantStatus) + + assertResponseStatusCode(t, w, wantStatus) + assertLocation(t, w, wantLocation) + }) + + t.Run("inertia request", func(t *testing.T) { + t.Parallel() + + wantLocation := "https://example.com/foo" + wantInertiaLocation := "" + + w, r := requestMock(http.MethodGet, "/") + asInertiaRequest(r) + + I().Redirect(w, r, wantLocation, http.StatusMovedPermanently) + + assertLocation(t, w, wantLocation) + assertResponseStatusCode(t, w, http.StatusMovedPermanently) + assertInertiaLocation(t, w, wantInertiaLocation) + }) + + t.Run("flash validation errors", func(t *testing.T) { + t.Parallel() + + w, r := requestMock(http.MethodGet, "/") + + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + errors := ValidationErrors{ + "foo": "bar", + "baz": "quz", + } + + r = r.WithContext(WithValidationErrors(r.Context(), errors)) + i.Redirect(w, r, "https://example.com/foo") + + if !reflect.DeepEqual(flashProvider.errors, errors) { + t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) + } + }) +} + func TestInertia_Back(t *testing.T) { t.Parallel() - t.Run("plain redirect with default status", func(t *testing.T) { + t.Run("with default status", func(t *testing.T) { t.Parallel() wantStatus := http.StatusFound @@ -565,7 +626,7 @@ func TestInertia_Back(t *testing.T) { assertLocation(t, w, wantLocation) }) - t.Run("plain redirect with specified status", func(t *testing.T) { + t.Run("with specified status", func(t *testing.T) { t.Parallel() wantStatus := http.StatusMovedPermanently @@ -574,28 +635,53 @@ func TestInertia_Back(t *testing.T) { w, r := requestMock(http.MethodGet, "/") r.Header.Set("Referer", wantLocation) - I().Location(w, r, wantLocation, wantStatus) + I().Back(w, r, wantStatus) assertResponseStatusCode(t, w, wantStatus) assertLocation(t, w, wantLocation) }) - t.Run("inertia location", func(t *testing.T) { + t.Run("inertia request", func(t *testing.T) { t.Parallel() - wantLocation := "" - wantInertiaLocation := "https://example.com/foo" + wantLocation := "https://example.com/foo" + wantInertiaLocation := "" w, r := requestMock(http.MethodGet, "/") r.Header.Set("Referer", wantLocation) asInertiaRequest(r) - I().Location(w, r, wantInertiaLocation, http.StatusMovedPermanently) + I().Back(w, r, http.StatusMovedPermanently) assertLocation(t, w, wantLocation) - assertResponseStatusCode(t, w, http.StatusConflict) + assertResponseStatusCode(t, w, http.StatusMovedPermanently) assertInertiaLocation(t, w, wantInertiaLocation) }) + + t.Run("flash validation errors", func(t *testing.T) { + t.Parallel() + + w, r := requestMock(http.MethodGet, "/") + r.Header.Set("Referer", "https://example.com/foo") + + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + errors := ValidationErrors{ + "foo": "bar", + "baz": "quz", + } + + r = r.WithContext(WithValidationErrors(r.Context(), errors)) + i.Back(w, r) + + if !reflect.DeepEqual(flashProvider.errors, errors) { + t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) + } + }) } func assertRootTemplateSuccess(t *testing.T, i *Inertia) { From 7aa976b26a78f511a2f1e73be7020a474936d9fe Mon Sep 17 00:00:00 2001 From: Roman Sarvarov Date: Thu, 18 Jul 2024 23:17:13 +0300 Subject: [PATCH 2/3] fix typo --- response.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/response.go b/response.go index b827ee0..5e2d035 100644 --- a/response.go +++ b/response.go @@ -69,12 +69,12 @@ func (i *Inertia) Location(w http.ResponseWriter, r *http.Request, url string, s redirectResponse(w, r, url, status...) } -// Back creates redirect response to the previous url. +// Back creates plain redirect response to the previous url. func (i *Inertia) Back(w http.ResponseWriter, r *http.Request, status ...int) { - i.Redirect(w, r, i.backURL(r), status...) + i.Redirect(w, r, backURL(r), status...) } -func (i *Inertia) backURL(r *http.Request) string { +func backURL(r *http.Request) string { // At the moment, it based only on the "Referer" HTTP header. return refererFromRequest(r) } From fe94854e5d717ae43f0101d921632697df79167e Mon Sep 17 00:00:00 2001 From: Roman Sarvarov Date: Thu, 18 Jul 2024 23:23:29 +0300 Subject: [PATCH 3/3] update README.md --- README.md | 9 ++------- middleware.go | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 45d540b..353e6dd 100644 --- a/README.md +++ b/README.md @@ -140,13 +140,8 @@ i.Render(w, r, "Some/Page", props) #### Redirects ([learn more](https://inertiajs.com/redirects)) ```go -func homeHandler(i *inertia.Inertia) http.Handler { - fn := func(w http.ResponseWriter, r *http.Request) { - i.Location(w, r, "/some-url") - } - - return http.HandlerFunc(fn) -} +i.Redirect(w, r, "https://example.com") // plain redirect +i.Location(w, r, "https://example.com") // external redirect ``` NOTES: diff --git a/middleware.go b/middleware.go index 653ed80..a53772b 100644 --- a/middleware.go +++ b/middleware.go @@ -56,7 +56,7 @@ func (i *Inertia) Middleware(next http.Handler) http.Handler { // Determines what to do when an Inertia action returned empty response. // By default, we will redirect the user back to where he came from. if w2.StatusCode() == http.StatusOK && w2.IsEmpty() { - backURL := i.backURL(r) + backURL := backURL(r) if backURL != "" { setInertiaLocationInResponse(w2, backURL)