Skip to content

Commit

Permalink
Merge pull request #14 from toshi0607/feature/add-bucket-conf
Browse files Browse the repository at this point in the history
Add bucket conf
  • Loading branch information
toshi0607 authored Aug 13, 2022
2 parents b3a48ac + faef36d commit ca1265c
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 68 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ chi-prometheus is used as a middleware. It also supports both a default registry
r.Get("/healthz", [YOUR HandlerFunc])
```

### Configuration

Latency histogram bucket is configurable with `CHI_PROMETHEUS_LATENCY_BUCKETS`. Default values are `300, 1200, 5000` (milliseconds).

You can override them as follows;

```shell
# comma separated string value
CHI_PROMETHEUS_LATENCY_BUCKETS="100,200,300,400"
```

## Install

```console
Expand Down
30 changes: 26 additions & 4 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package chiprometheus

import (
"net/http"
"os"
"strconv"
"strings"
"time"

"github.com/go-chi/chi/v5"
Expand All @@ -10,12 +13,14 @@ import (
)

var (
defaultBuckets = []float64{300, 1200, 5000}
bucketsConfig = []float64{300, 1200, 5000}
)

const (
RequestsCollectorName = "chi_requests_total"
LatencyCollectorName = "chi_request_duration_milliseconds"
// EnvChiPrometheusLatencyBuckets represents an environment variable, which is formatted like "100,200,300,400" as string
EnvChiPrometheusLatencyBuckets = "CHI_PROMETHEUS_LATENCY_BUCKETS"
RequestsCollectorName = "chi_requests_total"
LatencyCollectorName = "chi_request_duration_milliseconds"
)

// Middleware is a handler that exposes prometheus metrics for the number of requests,
Expand All @@ -25,8 +30,25 @@ type Middleware struct {
latency *prometheus.HistogramVec
}

func setBucket() {
var buckets []float64
conf, ok := os.LookupEnv(EnvChiPrometheusLatencyBuckets)
if ok {
for _, v := range strings.Split(conf, ",") {
f64v, err := strconv.ParseFloat(v, 64)
if err != nil {
panic(err)
}
buckets = append(buckets, f64v)
}
bucketsConfig = buckets
}
}

// New returns a new prometheus middleware for the provided service name.
func New(name string) *Middleware {
setBucket()

var m Middleware
m.requests = prometheus.NewCounterVec(
prometheus.CounterOpts{
Expand All @@ -39,7 +61,7 @@ func New(name string) *Middleware {
Name: LatencyCollectorName,
Help: "Time spent on the request partitioned by status code, method and HTTP path.",
ConstLabels: prometheus.Labels{"service": name},
Buckets: defaultBuckets,
Buckets: bucketsConfig,
}, []string{"code", "method", "path"})

return &m
Expand Down
195 changes: 131 additions & 64 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
Expand All @@ -20,7 +21,7 @@ import (
const testHost = "http://localhost"

func TestMiddleware_MustRegisterDefault(t *testing.T) {
t.Run("must panic without collectors", func(t *testing.T) {
t.Run("without collectors", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("must have panicked")
Expand All @@ -30,7 +31,7 @@ func TestMiddleware_MustRegisterDefault(t *testing.T) {
m.MustRegisterDefault()
})

t.Run("must not panic with collectors", func(t *testing.T) {
t.Run("with collectors", func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("must not have panicked")
Expand All @@ -56,11 +57,39 @@ func TestMiddleware_Collectors(t *testing.T) {
}

func testHandler(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
time.Sleep(time.Duration(rand.Intn(5)) * time.Millisecond)
w.WriteHeader(http.StatusOK)
}

func makeRequest(t *testing.T, r *chi.Mux, paths [][]string) string {
t.Helper()
rec := httptest.NewRecorder()
for _, p := range paths {
u, err := url.JoinPath(testHost, p...)
if err != nil {
t.Error(err)
}
req, err := http.NewRequest("GET", u, nil)
if err != nil {
t.Error(err)
}
r.ServeHTTP(rec, req)
}
return rec.Body.String()
}

func TestMiddleware_Handler(t *testing.T) {
tests := map[string]struct {
body string
want bool
}{
"request header": {chiprometheus.RequestsCollectorName, true},
"latency header": {chiprometheus.LatencyCollectorName, true},
"bob": {`chi_request_duration_milliseconds_count{code="OK",method="GET",path="/users/bob",service="test"} 1`, false},
"alice": {`chi_request_duration_milliseconds_count{code="OK",method="GET",path="/users/alice",service="test"} 1`, false},
"path variable": {`chi_request_duration_milliseconds_count{code="OK",method="GET",path="/users/{firstName}",service="test"} 2`, true},
}

r := chi.NewRouter()
m := chiprometheus.New("test")
m.MustRegisterDefault()
Expand All @@ -73,53 +102,37 @@ func TestMiddleware_Handler(t *testing.T) {
r.Handle("/metrics", promhttp.Handler())
r.Get("/healthz", testHandler)
r.Get("/users/{firstName}", testHandler)

paths := [][]string{
{"healthz"},
{"users", "bob"},
{"users", "alice"},
{"metrics"},
}
rec := httptest.NewRecorder()
for _, p := range paths {
u, err := url.JoinPath(testHost, p...)
if err != nil {
t.Error(err)
}
req, err := http.NewRequest("GET", u, nil)
if err != nil {
t.Error(err)
}
r.ServeHTTP(rec, req)
}
body := rec.Body.String()
got := makeRequest(t, r, paths)

if !strings.Contains(body, chiprometheus.RequestsCollectorName) {
t.Errorf("body should contain request total entry '%s'", chiprometheus.RequestsCollectorName)
}
if !strings.Contains(body, chiprometheus.LatencyCollectorName) {
t.Errorf("body should contain request duration entry '%s'", chiprometheus.LatencyCollectorName)
}

healthzCount := `chi_request_duration_milliseconds_count{code="OK",method="GET",path="/healthz",service="test"} 1`
bobCount := `chi_request_duration_milliseconds_count{code="OK",method="GET",path="/users/bob",service="test"} 1`
aliceCount := `chi_request_duration_milliseconds_count{code="OK",method="GET",path="/users/alice",service="test"} 1`
aggregatedCount := `chi_request_duration_milliseconds_count{code="OK",method="GET",path="/users/{firstName}",service="test"} 2`
if !strings.Contains(body, healthzCount) {
t.Errorf("body should contain healthz count summary '%s'", healthzCount)
}
if strings.Contains(body, bobCount) {
t.Errorf("body should NOT contain Bob count summary '%s'", bobCount)
}
if strings.Contains(body, aliceCount) {
t.Errorf("body should NOT contain Alice count summary '%s'", aliceCount)
}
if !strings.Contains(body, aggregatedCount) {
t.Errorf("body should contain first name count summary '%s'", aggregatedCount)
for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()
if tt.want && !strings.Contains(got, tt.body) {
t.Fatalf("body should contain %s", tt.body)
} else if !tt.want && strings.Contains(got, tt.body) {
t.Fatalf("body should NOT contain %s", tt.body)
}
})
}
}

func TestMiddleware_HandlerWithCustomRegistry(t *testing.T) {
tests := map[string]struct {
want string
}{
"request header": {chiprometheus.RequestsCollectorName},
"latency header": {chiprometheus.LatencyCollectorName},
"bob": {"promhttp_metric_handler_requests_total"},
"alice": {"go_goroutines"},
}

r := chi.NewRouter()
reg := prometheus.NewRegistry()
if err := reg.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})); err != nil {
Expand All @@ -138,40 +151,94 @@ func TestMiddleware_HandlerWithCustomRegistry(t *testing.T) {
promh := promhttp.InstrumentMetricHandler(
reg, promhttp.HandlerFor(reg, promhttp.HandlerOpts{}),
)

r.Use(m.Handler)
r.Handle("/metrics", promh)
r.Get("/healthz", testHandler)

paths := [][]string{
{"healthz"},
{"metrics"},
}
rec := httptest.NewRecorder()
for _, p := range paths {
u, err := url.JoinPath(testHost, p...)
if err != nil {
t.Error(err)
got := makeRequest(t, r, paths)

for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()
if !strings.Contains(got, tt.want) {
t.Fatalf("body should contain %s", tt.want)
}
})
}
}

func TestMiddleware_HandlerWithBucketEnv(t *testing.T) {
key := chiprometheus.EnvChiPrometheusLatencyBuckets

t.Run("with invalid env", func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Errorf("must have panicked")
}
}()
if err := os.Setenv(key, "invalid value"); err != nil {
t.Fatalf("failed to set %s", key)
}
req, err := http.NewRequest("GET", u, nil)
if err != nil {
t.Error(err)
t.Cleanup(func() { _ = os.Unsetenv(key) })
chiprometheus.New("test")
})

t.Run("with valid env", func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("must not have panicked")
}
}()

tests := map[string]struct {
body string
want bool
}{
"le 101": {`path="/healthz",service="test",le="101"`, true},
"le 201": {`path="/healthz",service="test",le="201"`, true},
"le +Inf": {`path="/healthz",service="test",le="+Inf"`, true},
// default values should be overwritten
"le 300": {`path="/healthz",service="test",le="300"`, false},
"le 1200": {`path="/healthz",service="test",le="1200"`, false},
"le 5000": {`path="/healthz",service="test",le="1200"`, false},
}
r.ServeHTTP(rec, req)
}
body := rec.Body.String()

if !strings.Contains(body, chiprometheus.RequestsCollectorName) {
t.Errorf("body should contain request total entry '%s'", chiprometheus.RequestsCollectorName)
}
if !strings.Contains(body, chiprometheus.LatencyCollectorName) {
t.Errorf("body should contain request duration entry '%s'", chiprometheus.LatencyCollectorName)
}
if err := os.Setenv(key, "101,201"); err != nil {
t.Fatalf("failed to set %s", key)
}
t.Cleanup(func() { _ = os.Unsetenv(key) })

if !strings.Contains(body, "promhttp_metric_handler_requests_total") {
t.Error("body should contain promhttp_metric_handler_requests_total from ProcessCollector")
}
if !strings.Contains(body, "go_goroutines") {
t.Errorf("body should contain Go runtime metrics from GoCollector")
}
r := chi.NewRouter()
m := chiprometheus.New("test")
m.MustRegisterDefault()
t.Cleanup(func() {
for _, c := range m.Collectors() {
prometheus.Unregister(c)
}
})
r.Use(m.Handler)
r.Handle("/metrics", promhttp.Handler())
r.Get("/healthz", testHandler)
paths := [][]string{
{"healthz"},
{"metrics"},
}
got := makeRequest(t, r, paths)

for name, tt := range tests {
tt := tt
t.Run(name, func(t *testing.T) {
t.Parallel()
if tt.want && !strings.Contains(got, tt.body) {
t.Fatalf("body should contain %s", tt.body)
} else if !tt.want && strings.Contains(got, tt.body) {
t.Fatalf("body should NOT contain %s", tt.body)
}
})
}
})
}

0 comments on commit ca1265c

Please sign in to comment.