From 401b5e4779cccbd03aa6a1cf6519578b1c5e4260 Mon Sep 17 00:00:00 2001 From: dylanhitt Date: Fri, 20 Dec 2024 12:43:17 -0500 Subject: [PATCH] chore: add tests for use of WithListener --- serve.go | 5 +---- serve_test.go | 44 +++++++++++++++++++++++++++++++++++--------- server_test.go | 17 +++++++++++++++++ 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/serve.go b/serve.go index 6eb55260..657d5994 100644 --- a/serve.go +++ b/serve.go @@ -53,11 +53,8 @@ func (s *Server) setupDefaultListener() error { return nil } listener, err := net.Listen("tcp", s.Addr) - if err != nil { - return err - } s.listener = listener - return nil + return err } func (s *Server) printStartupMessage() { diff --git a/serve_test.go b/serve_test.go index 18540c9c..d2cd7eca 100644 --- a/serve_test.go +++ b/serve_test.go @@ -390,13 +390,7 @@ func TestIni(t *testing.T) { } func TestServer_Run(t *testing.T) { - // This is not a standard test, it is here to ensure that the server can run. - // Please do not run this kind of test for your controllers, it is NOT unit testing. - t.Run("can run server", func(t *testing.T) { - s := NewServer( - WithoutLogger(), - ) - + runServer := func(s *Server) (*Server, func()) { Get(s, "/test", func(ctx ContextNoBody) (string, error) { return "OK", nil }) @@ -404,13 +398,21 @@ func TestServer_Run(t *testing.T) { go func() { s.Run() }() - defer func() { // stop our test server when we are done + return s, func() { // stop our test server when we are done ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) if err := s.Server.Shutdown(ctx); err != nil { t.Log(err) } cancel() - }() + } + } + // This is not a standard test, it is here to ensure that the server can run. + // Please do not run this kind of test for your controllers, it is NOT unit testing. + t.Run("can run server", func(t *testing.T) { + s, shutdown := runServer(NewServer( + WithoutLogger(), + )) + defer shutdown() require.Eventually(t, func() bool { req := httptest.NewRequest("GET", "/test", nil) @@ -420,6 +422,30 @@ func TestServer_Run(t *testing.T) { return w.Body.String() == `OK` }, 5*time.Second, 500*time.Millisecond) }) + + t.Run("can run server WithListener", func(t *testing.T) { + listener, err := net.Listen("tcp", ":8080") + require.NoError(t, err) + s, shutdown := runServer(NewServer( + WithListener(listener), + )) + defer shutdown() + + require.Eventually(t, func() bool { + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + s.Mux.ServeHTTP(w, req) + + return w.Body.String() == `OK` + }, 5*time.Second, 500*time.Millisecond) + }) + + t.Run("invalid address", func(t *testing.T) { + s := NewServer( + WithAddr("----:nope"), + ) + require.Error(t, s.Run()) + }) } func TestServer_RunTLS(t *testing.T) { diff --git a/server_test.go b/server_test.go index 45dd8bcd..895fcb8a 100644 --- a/server_test.go +++ b/server_test.go @@ -5,6 +5,7 @@ import ( "html/template" "io" "log/slog" + "net" "net/http" "net/http/httptest" "testing" @@ -339,6 +340,22 @@ func TestWithRequestContentType(t *testing.T) { }) } +func TestWithListener(t *testing.T) { + t.Run("with custom listener", func(t *testing.T) { + listener, err := net.Listen("tcp", ":8080") + require.NoError(t, err) + s := NewServer( + WithListener(listener), + ) + require.NotNil(t, s.listener) + }) + + t.Run("no custom listener", func(t *testing.T) { + s := NewServer() + require.Nil(t, s.listener) + }) +} + func TestCustomSerialization(t *testing.T) { s := NewServer( WithSerializer(func(w http.ResponseWriter, r *http.Request, a any) error {