Skip to content

Commit

Permalink
Test middlewares using gorilla mux
Browse files Browse the repository at this point in the history
This commit confirms that our middlewares implement the gorilla
Middleware interface.
  • Loading branch information
sevein committed Nov 29, 2024
1 parent f9a6b25 commit 78e8bb8
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 12 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/go-logr/logr v1.4.1
github.com/go-logr/zapr v1.3.0
github.com/google/go-cmp v0.6.0
github.com/gorilla/mux v1.8.1
go.temporal.io/api v1.29.2
go.temporal.io/sdk v1.26.0
go.uber.org/mock v0.4.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfF
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.2 h1:mhN09QQW1jEWeMF74zGR81R30z4VJzjZsfkUhuHF+DA=
github.com/googleapis/gax-go/v2 v2.12.2/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI=
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1 h1:/c3QmbOGMGTOumP2iT/rCwB7b0QDGLKzqOmktBjT+Is=
Expand Down
40 changes: 40 additions & 0 deletions middleware/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package middleware_test

import (
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/go-logr/logr/funcr"
"github.com/gorilla/mux"
"go.artefactual.dev/tools/middleware"
"gotest.tools/v3/assert"
)

func TestMiddlewares(t *testing.T) {
t.Parallel()

logged := ""
logger := funcr.New(func(prefix, args string) { logged += prefix + args }, funcr.Options{})

panicker := func(w http.ResponseWriter, r *http.Request) { panic("opsie") }
router := mux.NewRouter()
router.HandleFunc("/", panicker).Methods("GET")
router.Use(
middleware.Recover(logger),
middleware.WriteTimeout(0),
middleware.VersionHeader("", "v1.2.3"),
)

rw := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/", nil)

router.ServeHTTP(rw, req)

// Recover logs the panic error and handles the response.
assert.Assert(t, strings.Contains(logged, "Panic error recovered."))

// VersionHeader injects a header into the response.
assert.Equal(t, rw.Header().Get("X-Version"), "v1.2.3")
}
1 change: 1 addition & 0 deletions middleware/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/go-logr/logr"
)

// Recover from panics and log the error.
func Recover(logger logr.Logger) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
22 changes: 12 additions & 10 deletions middleware/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import (

// WriteTimeout sets the write deadline for writing the response. The zero value
// means no timeout.
func WriteTimeout(h http.Handler, timeout time.Duration) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rc := http.NewResponseController(w)
var deadline time.Time
if timeout != 0 {
deadline = time.Now().Add(timeout)
}
_ = rc.SetWriteDeadline(deadline)
h.ServeHTTP(w, r)
})
func WriteTimeout(timeout time.Duration) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rc := http.NewResponseController(w)
var deadline time.Time
if timeout != 0 {
deadline = time.Now().Add(timeout)
}
_ = rc.SetWriteDeadline(deadline)
h.ServeHTTP(w, r)
})
}
}
4 changes: 2 additions & 2 deletions middleware/timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func TestWriteTimeout(t *testing.T) {
t.Run("Sets a write timeout", func(t *testing.T) {
t.Parallel()

ts := httptest.NewServer(middleware.WriteTimeout(h, time.Microsecond))
ts := httptest.NewServer(middleware.WriteTimeout(time.Microsecond)(h))
defer ts.Close()

_, err := ts.Client().Get(ts.URL)
Expand All @@ -33,7 +33,7 @@ func TestWriteTimeout(t *testing.T) {
t.Run("Sets an unlimited write timeout", func(t *testing.T) {
t.Parallel()

ts := httptest.NewServer(middleware.WriteTimeout(h, 0))
ts := httptest.NewServer(middleware.WriteTimeout(0)(h))
defer ts.Close()

resp, err := ts.Client().Get(ts.URL)
Expand Down
2 changes: 2 additions & 0 deletions middleware/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package middleware

import "net/http"

// VersionHeader sets a version header on the response. If name is empty, it
// defaults to "X-Version".
func VersionHeader(name, version string) func(http.Handler) http.Handler {
if name == "" {
name = "X-Version"
Expand Down

0 comments on commit 78e8bb8

Please sign in to comment.