Skip to content

Commit

Permalink
store history clearing in session so it also works with redirects
Browse files Browse the repository at this point in the history
  • Loading branch information
romsar committed Oct 19, 2024
1 parent b40fce0 commit ff08ab9
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 93 deletions.
18 changes: 17 additions & 1 deletion helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func withValidationErrors(r *http.Request, errors ValidationErrors) {
*r = *r.WithContext(SetValidationErrors(r.Context(), errors))
}

func withClearHistory(r *http.Request) {
*r = *r.WithContext(ClearHistory(r.Context()))
}

func assertResponseStatusCode(t *testing.T, w *httptest.ResponseRecorder, want int) {
t.Helper()

Expand Down Expand Up @@ -194,9 +198,12 @@ func tmpFile(t *testing.T, content string) *os.File {
}

type flashProviderMock struct {
errors ValidationErrors
errors ValidationErrors
clearHistory bool
}

var _ FlashProvider = (*flashProviderMock)(nil)

func (p *flashProviderMock) FlashErrors(_ context.Context, errors ValidationErrors) error {
p.errors = errors
return nil
Expand All @@ -205,3 +212,12 @@ func (p *flashProviderMock) FlashErrors(_ context.Context, errors ValidationErro
func (p *flashProviderMock) GetErrors(_ context.Context) (ValidationErrors, error) {
return p.errors, nil
}

func (p *flashProviderMock) FlashClearHistory(_ context.Context) error {
p.clearHistory = true
return nil
}

func (p *flashProviderMock) ShouldClearHistory(_ context.Context) (bool, error) {
return p.clearHistory, nil
}
4 changes: 3 additions & 1 deletion inertia.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,12 @@ type Logger interface {
Println(v ...any)
}

// FlashProvider defines an interface for flash data provider.
// FlashProvider defines an interface for a flash data provider.
type FlashProvider interface {
FlashErrors(ctx context.Context, errors ValidationErrors) error
GetErrors(ctx context.Context) (ValidationErrors, error)
ShouldClearHistory(ctx context.Context) (bool, error)
FlashClearHistory(ctx context.Context) error
}

// ShareProp adds passed prop to shared props.
Expand Down
27 changes: 24 additions & 3 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ func (i *Inertia) Middleware(next http.Handler) http.Handler {
// https://github.com/inertiajs/inertia-laravel/pull/404
setInertiaVaryInResponse(w)

// Resolve validation errors from the flash data provider.
r = i.resolveValidationErrors(r)
// Resolve validation errors and clear history from the flash data provider.
{
r = i.resolveValidationErrors(r)
r = i.resolveClearHistory(r)
}

if !IsInertiaRequest(r) {
next.ServeHTTP(w, r)
Expand Down Expand Up @@ -76,7 +79,7 @@ func (i *Inertia) resolveValidationErrors(r *http.Request) *http.Request {

validationErrors, err := i.flash.GetErrors(r.Context())
if err != nil {
i.logger.Printf("get validation errors from flash data provider error: %s", err)
i.logger.Printf("get validation errors from the flash data provider error: %s", err)
return r
}

Expand All @@ -87,6 +90,24 @@ func (i *Inertia) resolveValidationErrors(r *http.Request) *http.Request {
return r.WithContext(SetValidationErrors(r.Context(), validationErrors))
}

func (i *Inertia) resolveClearHistory(r *http.Request) *http.Request {
if i.flash == nil {
return r
}

clearHistory, err := i.flash.ShouldClearHistory(r.Context())
if err != nil {
i.logger.Printf("get clear history flag from the flash data provider error: %s", err)
return r
}

if clearHistory {
r = r.WithContext(ClearHistory(r.Context()))
}

return r
}

func (i *Inertia) copyWrapperResponse(dst http.ResponseWriter, src *inertiaResponseWrapper) {
i.copyWrapperHeaders(dst, src)
i.copyWrapperStatusCode(dst, src)
Expand Down
63 changes: 45 additions & 18 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,59 @@ func TestInertia_Middleware(t *testing.T) {
assertResponseStatusCode(t, w, http.StatusOK)
})

t.Run("resolve validation errors from flash data provider", func(t *testing.T) {
t.Run("flash", func(t *testing.T) {
t.Parallel()

w, r := requestMock(http.MethodGet, "/")
t.Run("validation errors", func(t *testing.T) {
t.Parallel()

want := ValidationErrors{
"foo": "baz",
"baz": "quz",
}
w, r := requestMock(http.MethodGet, "/")

flashProvider := &flashProviderMock{
errors: want,
}
want := ValidationErrors{
"foo": "baz",
"baz": "quz",
}

flashProvider := &flashProviderMock{
errors: want,
}

i := I(func(i *Inertia) {
i.flash = flashProvider
})

var got ValidationErrors
i.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got = ValidationErrorsFromContext(r.Context())
})).ServeHTTP(w, r)

i := I(func(i *Inertia) {
i.flash = flashProvider
if !reflect.DeepEqual(got, want) {
t.Fatalf("validation errors=%#v, want=%#v", got, want)
}
})

var got ValidationErrors
i.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got = ValidationErrorsFromContext(r.Context())
})).ServeHTTP(w, r)
t.Run("clear history", func(t *testing.T) {
t.Parallel()

if !reflect.DeepEqual(got, want) {
t.Fatalf("validation errors=%#v, want=%#v", got, want)
}
w, r := requestMock(http.MethodGet, "/")

flashProvider := &flashProviderMock{
clearHistory: true,
}

i := I(func(i *Inertia) {
i.flash = flashProvider
})

var got bool
i.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got = ClearHistoryFromContext(r.Context())
})).ServeHTTP(w, r)

if !got {
t.Fatalf("clear history=%v, want=true", got)
}
})
})
})

Expand Down
4 changes: 2 additions & 2 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ func WithSSR(url ...string) Option {
}

// WithFlashProvider returns Option that will set Inertia's flash data provider.
func WithFlashProvider(flashData FlashProvider) Option {
func WithFlashProvider(flash FlashProvider) Option {
return func(i *Inertia) error {
i.flash = flashData
i.flash = flash
return nil
}
}
Expand Down
5 changes: 1 addition & 4 deletions option_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gonertia

import (
"fmt"
"io"
"log"
"reflect"
Expand Down Expand Up @@ -92,8 +91,6 @@ func TestWithJSONMarshaller(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}

fmt.Println(got)

if got != want {
t.Fatalf("JSONMarshaller.Decode()=%s, want=%s", got, want)
}
Expand Down Expand Up @@ -249,7 +246,7 @@ func TestWithFlashProvider(t *testing.T) {
}

if i.flash != want {
t.Fatalf("flash provider=%v, want=%s", i.flash, want)
t.Fatalf("flash provider=%v, want=%v", i.flash, want)
}
}

Expand Down
32 changes: 27 additions & 5 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ type ValidationErrors map[string]any
// If request was made by Inertia - sets status to 409 and url will be in "X-Inertia-Location" header.
// Otherwise, it will do an HTTP redirect with specified status (default is 302 for GET, 303 for POST/PUT/PATCH).
func (i *Inertia) Location(w http.ResponseWriter, r *http.Request, url string, status ...int) {
i.flashValidationErrorsFromContext(r.Context())
i.flashContext(r.Context())

if IsInertiaRequest(r) {
setInertiaLocationInResponse(w, url)
Expand All @@ -179,10 +179,16 @@ func backURL(r *http.Request) string {

// Redirect creates plain redirect response.
func (i *Inertia) Redirect(w http.ResponseWriter, r *http.Request, url string, status ...int) {
i.flashValidationErrorsFromContext(r.Context())
i.flashContext(r.Context())

redirectResponse(w, r, url, status...)
}

func (i *Inertia) flashContext(ctx context.Context) {
i.flashValidationErrorsFromContext(ctx)
i.flashClearHistoryFromContext(ctx)
}

func (i *Inertia) flashValidationErrorsFromContext(ctx context.Context) {
if i.flash == nil {
return
Expand All @@ -199,6 +205,22 @@ func (i *Inertia) flashValidationErrorsFromContext(ctx context.Context) {
}
}

func (i *Inertia) flashClearHistoryFromContext(ctx context.Context) {
if i.flash == nil {
return
}

clearHistory := ClearHistoryFromContext(ctx)
if !clearHistory {
return
}

err := i.flash.FlashClearHistory(ctx)
if err != nil {
i.logger.Printf("cannot flash clear history: %s", err)
}
}

// Render returns response with Inertia data.
//
// If request was made by Inertia - it will return data in JSON format.
Expand Down Expand Up @@ -241,9 +263,9 @@ func (i *Inertia) buildPage(r *http.Request, component string, props Props) (*pa
deferredProps := resolveDeferredProps(r, component, props)
mergeProps := resolveMergeProps(r, props)

props, err := i.resolveProperties(r, component, props)
props, err := i.resolveProps(r, component, props)
if err != nil {
return nil, fmt.Errorf("prepare props: %w", err)
return nil, fmt.Errorf("resolve props: %w", err)
}

return &page{
Expand All @@ -258,7 +280,7 @@ func (i *Inertia) buildPage(r *http.Request, component string, props Props) (*pa
}, nil
}

func (i *Inertia) resolveProperties(r *http.Request, component string, props Props) (Props, error) {
func (i *Inertia) resolveProps(r *http.Request, component string, props Props) (Props, error) {
result := make(Props)

{
Expand Down
Loading

0 comments on commit ff08ab9

Please sign in to comment.