diff --git a/inertia.go b/inertia.go index fe07a02..a9a84d3 100644 --- a/inertia.go +++ b/inertia.go @@ -1,6 +1,7 @@ package gonertia import ( + "context" "fmt" "html/template" "io" @@ -21,6 +22,8 @@ type Inertia struct { ssrURL string ssrHTTPClient *http.Client + errorsStore errorsStore + containerID string version string jsonMarshaller JSONMarshaller @@ -87,6 +90,11 @@ type logger interface { Println(v ...any) } +type errorsStore interface { + Push(ctx context.Context, errors ValidationErrors) error + Pop(ctx context.Context) (ValidationErrors, error) +} + // ShareProp adds passed prop to shared props. func (i *Inertia) ShareProp(key string, val any) { i.sharedProps[key] = val diff --git a/option.go b/option.go index c9ef92c..a34454a 100644 --- a/option.go +++ b/option.go @@ -80,3 +80,11 @@ func WithSSR(url ...string) Option { return nil } } + +// WithErrorsStore returns Option that will set Inertia's errors store. +func WithErrorsStore(errorsStore errorsStore) Option { + return func(i *Inertia) error { + i.errorsStore = errorsStore + return nil + } +} diff --git a/props.go b/props.go index 019ae8b..7f771a6 100644 --- a/props.go +++ b/props.go @@ -1,6 +1,7 @@ package gonertia import ( + "context" "fmt" "net/http" ) @@ -45,12 +46,12 @@ type ValidationErrors map[string]any func (i *Inertia) prepareProps(r *http.Request, component string, props Props) (Props, error) { result := make(Props) - // Add validation errors from context. - ctxValidationErrors, err := ValidationErrorsFromContext(r.Context()) + // Add validation errors to the result. + validationErrors, err := i.resolveValidationErrors(r) if err != nil { - return nil, fmt.Errorf("getting validation errors from context: %w", err) + return nil, fmt.Errorf("resolve validation errors: %w", err) } - result["errors"] = AlwaysProp{ctxValidationErrors} + result["errors"] = AlwaysProp{validationErrors} // Add shared props to the result. for key, val := range i.sharedProps { @@ -112,6 +113,43 @@ func (i *Inertia) prepareProps(r *http.Request, component string, props Props) ( return result, nil } +func (i *Inertia) resolveValidationErrors(r *http.Request) (ValidationErrors, error) { + // Add validation errors from storage. + storageValidationErrors, err := i.restoreValidationErrors(r.Context()) + if err != nil { + return nil, fmt.Errorf("getting validation errors from context: %w", err) + } + + // ... and from context. + ctxValidationErrors, err := ValidationErrorsFromContext(r.Context()) + if err != nil { + return nil, fmt.Errorf("getting validation errors from context: %w", err) + } + + validationErrors := make(ValidationErrors) + for key, val := range storageValidationErrors { + validationErrors[key] = val + } + for key, val := range ctxValidationErrors { + validationErrors[key] = val + } + + return ctxValidationErrors, nil +} + +func (i *Inertia) restoreValidationErrors(ctx context.Context) (ValidationErrors, error) { + if i.errorsStore == nil { + return nil, nil + } + + storageValidationErrors, err := i.errorsStore.Pop(ctx) + if err != nil { + return nil, fmt.Errorf("errors store pop: %w", err) + } + + return storageValidationErrors, nil +} + func (i *Inertia) getOnlyAndExcept(r *http.Request, component string) (only, except map[string]struct{}) { // Partial reloads only work for visits made to the same page component. // diff --git a/response.go b/response.go index e43ebc8..e4219c6 100644 --- a/response.go +++ b/response.go @@ -218,6 +218,7 @@ func (i *Inertia) htmlContainer(pageJSON []byte) (inertia, _ template.HTML, _ er // Otherwise, it will do an HTTP redirect with specified status (default is 302 for GET, 303 for POST/PUT/PATCH). func (i *Inertia) Location(w http.ResponseWriter, r *http.Request, url string, status ...int) { if IsInertiaRequest(r) { + i.captureValidationErrors(r) setInertiaLocationInResponse(w, url) return } @@ -225,6 +226,22 @@ func (i *Inertia) Location(w http.ResponseWriter, r *http.Request, url string, s redirectResponse(w, r, url, status...) } +func (i *Inertia) captureValidationErrors(r *http.Request) { + if i.errorsStore == nil { + return + } + + validationErrors, err := ValidationErrorsFromContext(r.Context()) + if err != nil { + i.logger.Printf("invalid validation errors from context: %s", err) + } + + err = i.errorsStore.Push(r.Context(), validationErrors) + if err != nil { + i.logger.Printf("cannot push validation errors to storage: %s", err) + } +} + // Back creates redirect response to the previous url. func (i *Inertia) Back(w http.ResponseWriter, r *http.Request, status ...int) { i.Location(w, r, i.backURL(r), status...)