From ff08ab9ed3a6e6e6f84bc84473b48d4f0af72c85 Mon Sep 17 00:00:00 2001 From: Roman Sarvarov Date: Sun, 20 Oct 2024 02:36:30 +0300 Subject: [PATCH] store history clearing in session so it also works with redirects --- helpers_test.go | 18 +++- inertia.go | 4 +- middleware.go | 27 +++++- middleware_test.go | 63 ++++++++++---- option.go | 4 +- option_test.go | 5 +- response.go | 32 +++++-- response_test.go | 212 ++++++++++++++++++++++++++++++++------------- 8 files changed, 272 insertions(+), 93 deletions(-) diff --git a/helpers_test.go b/helpers_test.go index 8d9a4dc..b5fe025 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -66,6 +66,10 @@ func withValidationErrors(r *http.Request, errors ValidationErrors) { *r = *r.WithContext(SetValidationErrors(r.Context(), errors)) } +func withClearHistory(r *http.Request) { + *r = *r.WithContext(ClearHistory(r.Context())) +} + func assertResponseStatusCode(t *testing.T, w *httptest.ResponseRecorder, want int) { t.Helper() @@ -194,9 +198,12 @@ func tmpFile(t *testing.T, content string) *os.File { } type flashProviderMock struct { - errors ValidationErrors + errors ValidationErrors + clearHistory bool } +var _ FlashProvider = (*flashProviderMock)(nil) + func (p *flashProviderMock) FlashErrors(_ context.Context, errors ValidationErrors) error { p.errors = errors return nil @@ -205,3 +212,12 @@ func (p *flashProviderMock) FlashErrors(_ context.Context, errors ValidationErro func (p *flashProviderMock) GetErrors(_ context.Context) (ValidationErrors, error) { return p.errors, nil } + +func (p *flashProviderMock) FlashClearHistory(_ context.Context) error { + p.clearHistory = true + return nil +} + +func (p *flashProviderMock) ShouldClearHistory(_ context.Context) (bool, error) { + return p.clearHistory, nil +} diff --git a/inertia.go b/inertia.go index 2542c3f..dc58323 100644 --- a/inertia.go +++ b/inertia.go @@ -90,10 +90,12 @@ type Logger interface { Println(v ...any) } -// FlashProvider defines an interface for flash data provider. +// FlashProvider defines an interface for a flash data provider. type FlashProvider interface { FlashErrors(ctx context.Context, errors ValidationErrors) error GetErrors(ctx context.Context) (ValidationErrors, error) + ShouldClearHistory(ctx context.Context) (bool, error) + FlashClearHistory(ctx context.Context) error } // ShareProp adds passed prop to shared props. diff --git a/middleware.go b/middleware.go index 2448ed9..16353f4 100644 --- a/middleware.go +++ b/middleware.go @@ -17,8 +17,11 @@ func (i *Inertia) Middleware(next http.Handler) http.Handler { // https://github.com/inertiajs/inertia-laravel/pull/404 setInertiaVaryInResponse(w) - // Resolve validation errors from the flash data provider. - r = i.resolveValidationErrors(r) + // Resolve validation errors and clear history from the flash data provider. + { + r = i.resolveValidationErrors(r) + r = i.resolveClearHistory(r) + } if !IsInertiaRequest(r) { next.ServeHTTP(w, r) @@ -76,7 +79,7 @@ func (i *Inertia) resolveValidationErrors(r *http.Request) *http.Request { validationErrors, err := i.flash.GetErrors(r.Context()) if err != nil { - i.logger.Printf("get validation errors from flash data provider error: %s", err) + i.logger.Printf("get validation errors from the flash data provider error: %s", err) return r } @@ -87,6 +90,24 @@ func (i *Inertia) resolveValidationErrors(r *http.Request) *http.Request { return r.WithContext(SetValidationErrors(r.Context(), validationErrors)) } +func (i *Inertia) resolveClearHistory(r *http.Request) *http.Request { + if i.flash == nil { + return r + } + + clearHistory, err := i.flash.ShouldClearHistory(r.Context()) + if err != nil { + i.logger.Printf("get clear history flag from the flash data provider error: %s", err) + return r + } + + if clearHistory { + r = r.WithContext(ClearHistory(r.Context())) + } + + return r +} + func (i *Inertia) copyWrapperResponse(dst http.ResponseWriter, src *inertiaResponseWrapper) { i.copyWrapperHeaders(dst, src) i.copyWrapperStatusCode(dst, src) diff --git a/middleware_test.go b/middleware_test.go index 1ee56e1..09dd514 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -24,32 +24,59 @@ func TestInertia_Middleware(t *testing.T) { assertResponseStatusCode(t, w, http.StatusOK) }) - t.Run("resolve validation errors from flash data provider", func(t *testing.T) { + t.Run("flash", func(t *testing.T) { t.Parallel() - w, r := requestMock(http.MethodGet, "/") + t.Run("validation errors", func(t *testing.T) { + t.Parallel() - want := ValidationErrors{ - "foo": "baz", - "baz": "quz", - } + w, r := requestMock(http.MethodGet, "/") - flashProvider := &flashProviderMock{ - errors: want, - } + want := ValidationErrors{ + "foo": "baz", + "baz": "quz", + } + + flashProvider := &flashProviderMock{ + errors: want, + } + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + var got ValidationErrors + i.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got = ValidationErrorsFromContext(r.Context()) + })).ServeHTTP(w, r) - i := I(func(i *Inertia) { - i.flash = flashProvider + if !reflect.DeepEqual(got, want) { + t.Fatalf("validation errors=%#v, want=%#v", got, want) + } }) - var got ValidationErrors - i.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - got = ValidationErrorsFromContext(r.Context()) - })).ServeHTTP(w, r) + t.Run("clear history", func(t *testing.T) { + t.Parallel() - if !reflect.DeepEqual(got, want) { - t.Fatalf("validation errors=%#v, want=%#v", got, want) - } + w, r := requestMock(http.MethodGet, "/") + + flashProvider := &flashProviderMock{ + clearHistory: true, + } + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + var got bool + i.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + got = ClearHistoryFromContext(r.Context()) + })).ServeHTTP(w, r) + + if !got { + t.Fatalf("clear history=%v, want=true", got) + } + }) }) }) diff --git a/option.go b/option.go index 8dd4451..9935dde 100644 --- a/option.go +++ b/option.go @@ -82,9 +82,9 @@ func WithSSR(url ...string) Option { } // WithFlashProvider returns Option that will set Inertia's flash data provider. -func WithFlashProvider(flashData FlashProvider) Option { +func WithFlashProvider(flash FlashProvider) Option { return func(i *Inertia) error { - i.flash = flashData + i.flash = flash return nil } } diff --git a/option_test.go b/option_test.go index 164a095..8be407c 100644 --- a/option_test.go +++ b/option_test.go @@ -1,7 +1,6 @@ package gonertia import ( - "fmt" "io" "log" "reflect" @@ -92,8 +91,6 @@ func TestWithJSONMarshaller(t *testing.T) { t.Fatalf("unexpected error: %s", err) } - fmt.Println(got) - if got != want { t.Fatalf("JSONMarshaller.Decode()=%s, want=%s", got, want) } @@ -249,7 +246,7 @@ func TestWithFlashProvider(t *testing.T) { } if i.flash != want { - t.Fatalf("flash provider=%v, want=%s", i.flash, want) + t.Fatalf("flash provider=%v, want=%v", i.flash, want) } } diff --git a/response.go b/response.go index 4fc276b..f6aa615 100644 --- a/response.go +++ b/response.go @@ -154,7 +154,7 @@ type ValidationErrors map[string]any // If request was made by Inertia - sets status to 409 and url will be in "X-Inertia-Location" header. // Otherwise, it will do an HTTP redirect with specified status (default is 302 for GET, 303 for POST/PUT/PATCH). func (i *Inertia) Location(w http.ResponseWriter, r *http.Request, url string, status ...int) { - i.flashValidationErrorsFromContext(r.Context()) + i.flashContext(r.Context()) if IsInertiaRequest(r) { setInertiaLocationInResponse(w, url) @@ -179,10 +179,16 @@ func backURL(r *http.Request) string { // Redirect creates plain redirect response. func (i *Inertia) Redirect(w http.ResponseWriter, r *http.Request, url string, status ...int) { - i.flashValidationErrorsFromContext(r.Context()) + i.flashContext(r.Context()) + redirectResponse(w, r, url, status...) } +func (i *Inertia) flashContext(ctx context.Context) { + i.flashValidationErrorsFromContext(ctx) + i.flashClearHistoryFromContext(ctx) +} + func (i *Inertia) flashValidationErrorsFromContext(ctx context.Context) { if i.flash == nil { return @@ -199,6 +205,22 @@ func (i *Inertia) flashValidationErrorsFromContext(ctx context.Context) { } } +func (i *Inertia) flashClearHistoryFromContext(ctx context.Context) { + if i.flash == nil { + return + } + + clearHistory := ClearHistoryFromContext(ctx) + if !clearHistory { + return + } + + err := i.flash.FlashClearHistory(ctx) + if err != nil { + i.logger.Printf("cannot flash clear history: %s", err) + } +} + // Render returns response with Inertia data. // // If request was made by Inertia - it will return data in JSON format. @@ -241,9 +263,9 @@ func (i *Inertia) buildPage(r *http.Request, component string, props Props) (*pa deferredProps := resolveDeferredProps(r, component, props) mergeProps := resolveMergeProps(r, props) - props, err := i.resolveProperties(r, component, props) + props, err := i.resolveProps(r, component, props) if err != nil { - return nil, fmt.Errorf("prepare props: %w", err) + return nil, fmt.Errorf("resolve props: %w", err) } return &page{ @@ -258,7 +280,7 @@ func (i *Inertia) buildPage(r *http.Request, component string, props Props) (*pa }, nil } -func (i *Inertia) resolveProperties(r *http.Request, component string, props Props) (Props, error) { +func (i *Inertia) resolveProps(r *http.Request, component string, props Props) (Props, error) { result := make(Props) { diff --git a/response_test.go b/response_test.go index d72bcbf..9f27776 100644 --- a/response_test.go +++ b/response_test.go @@ -706,56 +706,103 @@ func TestInertia_Location(t *testing.T) { assertInertiaLocation(t, w, wantInertiaLocation) }) - t.Run("flash validation errors", func(t *testing.T) { + t.Run("flash", func(t *testing.T) { t.Parallel() - t.Run("plain redirect", func(t *testing.T) { + t.Run("validation errors", func(t *testing.T) { t.Parallel() - w, r := requestMock(http.MethodGet, "/") + t.Run("plain redirect", func(t *testing.T) { + t.Parallel() - flashProvider := &flashProviderMock{} + w, r := requestMock(http.MethodGet, "/") - i := I(func(i *Inertia) { - i.flash = flashProvider + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + errors := ValidationErrors{ + "foo": "bar", + "baz": "quz", + } + + withValidationErrors(r, errors) + i.Location(w, r, "/foo") + + if !reflect.DeepEqual(flashProvider.errors, errors) { + t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) + } }) - errors := ValidationErrors{ - "foo": "bar", - "baz": "quz", - } + t.Run("inertia location", func(t *testing.T) { + t.Parallel() - withValidationErrors(r, errors) - i.Location(w, r, "/foo") + w, r := requestMock(http.MethodGet, "/") + asInertiaRequest(r) - if !reflect.DeepEqual(flashProvider.errors, errors) { - t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) - } + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + errors := ValidationErrors{ + "foo": "bar", + "baz": "quz", + } + + withValidationErrors(r, errors) + i.Location(w, r, "/foo", http.StatusMovedPermanently) + + if !reflect.DeepEqual(flashProvider.errors, errors) { + t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) + } + }) }) - t.Run("inertia location", func(t *testing.T) { + t.Run("clear history", func(t *testing.T) { t.Parallel() - w, r := requestMock(http.MethodGet, "/") - asInertiaRequest(r) + t.Run("plain redirect", func(t *testing.T) { + t.Parallel() - flashProvider := &flashProviderMock{} + w, r := requestMock(http.MethodGet, "/") - i := I(func(i *Inertia) { - i.flash = flashProvider + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + withClearHistory(r) + i.Location(w, r, "/foo") + + if !flashProvider.clearHistory { + t.Fatalf("got clear history=%v, want=true", flashProvider.clearHistory) + } }) - errors := ValidationErrors{ - "foo": "bar", - "baz": "quz", - } + t.Run("inertia location", func(t *testing.T) { + t.Parallel() - withValidationErrors(r, errors) - i.Location(w, r, "/foo", http.StatusMovedPermanently) + w, r := requestMock(http.MethodGet, "/") + asInertiaRequest(r) - if !reflect.DeepEqual(flashProvider.errors, errors) { - t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) - } + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + withClearHistory(r) + i.Location(w, r, "/foo", http.StatusMovedPermanently) + + if !flashProvider.clearHistory { + t.Fatalf("got clear history=%v, want=true", flashProvider.clearHistory) + } + }) }) }) } @@ -809,28 +856,51 @@ func TestInertia_Redirect(t *testing.T) { assertInertiaLocation(t, w, wantInertiaLocation) }) - t.Run("flash validation errors", func(t *testing.T) { + t.Run("flash", func(t *testing.T) { t.Parallel() - w, r := requestMock(http.MethodGet, "/") + t.Run("validation errors", func(t *testing.T) { + t.Parallel() + + w, r := requestMock(http.MethodGet, "/") + + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + errors := ValidationErrors{ + "foo": "bar", + "baz": "quz", + } - flashProvider := &flashProviderMock{} + withValidationErrors(r, errors) + i.Redirect(w, r, "https://example.com/foo") - i := I(func(i *Inertia) { - i.flash = flashProvider + if !reflect.DeepEqual(flashProvider.errors, errors) { + t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) + } }) - errors := ValidationErrors{ - "foo": "bar", - "baz": "quz", - } + t.Run("clear history", func(t *testing.T) { + t.Parallel() + + w, r := requestMock(http.MethodGet, "/") + + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) - withValidationErrors(r, errors) - i.Redirect(w, r, "https://example.com/foo") + withClearHistory(r) + i.Redirect(w, r, "https://example.com/foo") - if !reflect.DeepEqual(flashProvider.errors, errors) { - t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) - } + if !flashProvider.clearHistory { + t.Fatalf("got clear history=%v, want=true", flashProvider.clearHistory) + } + }) }) } @@ -886,29 +956,53 @@ func TestInertia_Back(t *testing.T) { assertInertiaLocation(t, w, wantInertiaLocation) }) - t.Run("flash validation errors", func(t *testing.T) { + t.Run("flash", func(t *testing.T) { t.Parallel() - w, r := requestMock(http.MethodGet, "/") - r.Header.Set("Referer", "https://example.com/foo") + t.Run("validation errors", func(t *testing.T) { + t.Parallel() + + w, r := requestMock(http.MethodGet, "/") + r.Header.Set("Referer", "https://example.com/foo") + + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + errors := ValidationErrors{ + "foo": "bar", + "baz": "quz", + } - flashProvider := &flashProviderMock{} + withValidationErrors(r, errors) + i.Back(w, r) - i := I(func(i *Inertia) { - i.flash = flashProvider + if !reflect.DeepEqual(flashProvider.errors, errors) { + t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) + } }) - errors := ValidationErrors{ - "foo": "bar", - "baz": "quz", - } + t.Run("clear history", func(t *testing.T) { + t.Parallel() - withValidationErrors(r, errors) - i.Back(w, r) + w, r := requestMock(http.MethodGet, "/") + r.Header.Set("Referer", "https://example.com/foo") + + flashProvider := &flashProviderMock{} + + i := I(func(i *Inertia) { + i.flash = flashProvider + }) + + withClearHistory(r) + i.Back(w, r) - if !reflect.DeepEqual(flashProvider.errors, errors) { - t.Fatalf("got validation errors=%#v, want=%#v", flashProvider.errors, errors) - } + if !flashProvider.clearHistory { + t.Fatalf("got clear history=%v, want=true", flashProvider.clearHistory) + } + }) }) }