From d6ecdac35abfc54c05b3a1a70a60b5bbbe794d90 Mon Sep 17 00:00:00 2001 From: Roman Sarvarov Date: Mon, 20 Jan 2025 20:05:06 +0300 Subject: [PATCH] allow inject http client #33 --- inertia.go | 6 ++++-- option.go | 9 ++++++++- option_test.go | 33 +++++++++++++++++++++++++-------- response.go | 4 ++-- response_test.go | 8 ++++---- 5 files changed, 43 insertions(+), 17 deletions(-) diff --git a/inertia.go b/inertia.go index 92b7523..3027401 100644 --- a/inertia.go +++ b/inertia.go @@ -22,8 +22,8 @@ type Inertia struct { flash FlashProvider - ssrURL string - ssrHTTPClient *http.Client + ssrURL string + ssrClient *http.Client containerID string version string @@ -46,6 +46,7 @@ func New(rootTemplateHTML string, opts ...Option) (*Inertia, error) { sharedProps: make(Props), sharedTemplateData: make(TemplateData), sharedTemplateFuncs: make(TemplateFuncs), + ssrClient: &http.Client{}, } for _, opt := range opts { @@ -108,6 +109,7 @@ func NewFromTemplate(rootTemplate *template.Template, opts ...Option) (*Inertia, logger: log.New(io.Discard, "", 0), sharedProps: make(Props), sharedTemplateData: make(TemplateData), + ssrClient: &http.Client{}, } for _, opt := range opts { diff --git a/option.go b/option.go index d8c0f9a..acce1d0 100644 --- a/option.go +++ b/option.go @@ -89,7 +89,14 @@ func WithSSR(url ...string) Option { } i.ssrURL = u - i.ssrHTTPClient = &http.Client{} + return nil + } +} + +// WithSSRClient returns Option that will set Inertia's SSR http client. +func WithSSRClient(ssrClient *http.Client) Option { + return func(i *Inertia) error { + i.ssrClient = ssrClient return nil } } diff --git a/option_test.go b/option_test.go index 1bc6f8b..a3a15cb 100644 --- a/option_test.go +++ b/option_test.go @@ -3,9 +3,11 @@ package gonertia import ( "io" "log" + "net/http" "reflect" "testing" "testing/fstest" + "time" ) func TestWithVersion(t *testing.T) { @@ -225,10 +227,6 @@ func TestWithSSR(t *testing.T) { t.Fatalf("unexpected error: %s", err) } - if i.ssrHTTPClient == nil { - t.Fatal("ssr http client is nil") - } - if i.ssrURL != wantURL { t.Fatalf("ssrURL=%s, want=%s", i.containerID, wantURL) } @@ -247,16 +245,35 @@ func TestWithSSR(t *testing.T) { t.Fatalf("unexpected error: %s", err) } - if i.ssrHTTPClient == nil { - t.Fatal("ssr http client is nil") - } - if i.ssrURL != wantURL { t.Fatalf("ssrURL=%s, want=%s", i.containerID, wantURL) } }) } +func TestWithSSRClient(t *testing.T) { + t.Parallel() + + i := I() + + want := &http.Client{ + Transport: nil, + CheckRedirect: nil, + Jar: nil, + Timeout: 5 * time.Second, + } + + option := WithSSRClient(want) + + if err := option(i); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if i.ssrClient != want { + t.Fatal("ssr http client was not set") + } +} + func TestWithFlashProvider(t *testing.T) { t.Parallel() diff --git a/response.go b/response.go index f6aa615..030beb3 100644 --- a/response.go +++ b/response.go @@ -515,7 +515,7 @@ func (i *Inertia) buildInertiaHTML(page *page) (inertia, inertiaHead template.HT } func (i *Inertia) isSSREnabled() bool { - return i.ssrURL != "" && i.ssrHTTPClient != nil + return i.ssrURL != "" } // htmlContainerSSR will send request with json marshaled page payload to ssr render endpoint. @@ -529,7 +529,7 @@ func (i *Inertia) htmlContainerSSR(pageJSON []byte) (inertia, inertiaHead templa } setJSONRequest(req) - resp, err := i.ssrHTTPClient.Do(req) + resp, err := i.ssrClient.Do(req) if err != nil { return "", "", fmt.Errorf("execute http request: %w", err) } diff --git a/response_test.go b/response_test.go index a8719fa..e98a7b9 100644 --- a/response_test.go +++ b/response_test.go @@ -147,7 +147,7 @@ func TestInertia_Render(t *testing.T) { i.rootTemplateHTML = rootTemplate i.version = "f8v01xv4h4" i.ssrURL = ts.URL - i.ssrHTTPClient = ts.Client() + i.ssrClient = ts.Client() }) successRunner(t, i) @@ -171,7 +171,7 @@ func TestInertia_Render(t *testing.T) { i.rootTemplate = tmpl i.version = "f8v01xv4h4" i.ssrURL = ts.URL - i.ssrHTTPClient = ts.Client() + i.ssrClient = ts.Client() }) successRunner(t, i) @@ -189,7 +189,7 @@ func TestInertia_Render(t *testing.T) { i.rootTemplateHTML = rootTemplate i.version = "f8v01xv4h4" i.ssrURL = ts.URL - i.ssrHTTPClient = ts.Client() + i.ssrClient = ts.Client() }) errorRunner(t, i) @@ -214,7 +214,7 @@ func TestInertia_Render(t *testing.T) { i.rootTemplate = tmpl i.version = "f8v01xv4h4" i.ssrURL = ts.URL - i.ssrHTTPClient = ts.Client() + i.ssrClient = ts.Client() }) errorRunner(t, i)