Skip to content

Commit

Permalink
allow inject http client #33
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Sarvarov committed Jan 20, 2025
1 parent a26e0f8 commit d6ecdac
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 17 deletions.
6 changes: 4 additions & 2 deletions inertia.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ type Inertia struct {

flash FlashProvider

ssrURL string
ssrHTTPClient *http.Client
ssrURL string
ssrClient *http.Client

containerID string
version string
Expand All @@ -46,6 +46,7 @@ func New(rootTemplateHTML string, opts ...Option) (*Inertia, error) {
sharedProps: make(Props),
sharedTemplateData: make(TemplateData),
sharedTemplateFuncs: make(TemplateFuncs),
ssrClient: &http.Client{},
}

for _, opt := range opts {
Expand Down Expand Up @@ -108,6 +109,7 @@ func NewFromTemplate(rootTemplate *template.Template, opts ...Option) (*Inertia,
logger: log.New(io.Discard, "", 0),
sharedProps: make(Props),
sharedTemplateData: make(TemplateData),
ssrClient: &http.Client{},
}

for _, opt := range opts {
Expand Down
9 changes: 8 additions & 1 deletion option.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ func WithSSR(url ...string) Option {
}

i.ssrURL = u
i.ssrHTTPClient = &http.Client{}
return nil
}
}

// WithSSRClient returns Option that will set Inertia's SSR http client.
func WithSSRClient(ssrClient *http.Client) Option {
return func(i *Inertia) error {
i.ssrClient = ssrClient
return nil
}
}
Expand Down
33 changes: 25 additions & 8 deletions option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package gonertia
import (
"io"
"log"
"net/http"
"reflect"
"testing"
"testing/fstest"
"time"
)

func TestWithVersion(t *testing.T) {
Expand Down Expand Up @@ -225,10 +227,6 @@ func TestWithSSR(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}

if i.ssrHTTPClient == nil {
t.Fatal("ssr http client is nil")
}

if i.ssrURL != wantURL {
t.Fatalf("ssrURL=%s, want=%s", i.containerID, wantURL)
}
Expand All @@ -247,16 +245,35 @@ func TestWithSSR(t *testing.T) {
t.Fatalf("unexpected error: %s", err)
}

if i.ssrHTTPClient == nil {
t.Fatal("ssr http client is nil")
}

if i.ssrURL != wantURL {
t.Fatalf("ssrURL=%s, want=%s", i.containerID, wantURL)
}
})
}

func TestWithSSRClient(t *testing.T) {
t.Parallel()

i := I()

want := &http.Client{
Transport: nil,
CheckRedirect: nil,
Jar: nil,
Timeout: 5 * time.Second,
}

option := WithSSRClient(want)

if err := option(i); err != nil {
t.Fatalf("unexpected error: %s", err)
}

if i.ssrClient != want {
t.Fatal("ssr http client was not set")
}
}

func TestWithFlashProvider(t *testing.T) {
t.Parallel()

Expand Down
4 changes: 2 additions & 2 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ func (i *Inertia) buildInertiaHTML(page *page) (inertia, inertiaHead template.HT
}

func (i *Inertia) isSSREnabled() bool {
return i.ssrURL != "" && i.ssrHTTPClient != nil
return i.ssrURL != ""
}

// htmlContainerSSR will send request with json marshaled page payload to ssr render endpoint.
Expand All @@ -529,7 +529,7 @@ func (i *Inertia) htmlContainerSSR(pageJSON []byte) (inertia, inertiaHead templa
}
setJSONRequest(req)

resp, err := i.ssrHTTPClient.Do(req)
resp, err := i.ssrClient.Do(req)
if err != nil {
return "", "", fmt.Errorf("execute http request: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func TestInertia_Render(t *testing.T) {
i.rootTemplateHTML = rootTemplate
i.version = "f8v01xv4h4"
i.ssrURL = ts.URL
i.ssrHTTPClient = ts.Client()
i.ssrClient = ts.Client()
})

successRunner(t, i)
Expand All @@ -171,7 +171,7 @@ func TestInertia_Render(t *testing.T) {
i.rootTemplate = tmpl
i.version = "f8v01xv4h4"
i.ssrURL = ts.URL
i.ssrHTTPClient = ts.Client()
i.ssrClient = ts.Client()
})

successRunner(t, i)
Expand All @@ -189,7 +189,7 @@ func TestInertia_Render(t *testing.T) {
i.rootTemplateHTML = rootTemplate
i.version = "f8v01xv4h4"
i.ssrURL = ts.URL
i.ssrHTTPClient = ts.Client()
i.ssrClient = ts.Client()
})

errorRunner(t, i)
Expand All @@ -214,7 +214,7 @@ func TestInertia_Render(t *testing.T) {
i.rootTemplate = tmpl
i.version = "f8v01xv4h4"
i.ssrURL = ts.URL
i.ssrHTTPClient = ts.Client()
i.ssrClient = ts.Client()
})

errorRunner(t, i)
Expand Down

0 comments on commit d6ecdac

Please sign in to comment.