-
Notifications
You must be signed in to change notification settings - Fork 1
/
rl.go
223 lines (202 loc) · 6.17 KB
/
rl.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
package rl
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"sync"
"time"
"github.com/2manymws/rl/counter"
"golang.org/x/sync/errgroup"
)
// Rule describes the rate limit to apply to a single request.
type Rule struct {
	// Key identifies the counter the request is counted against.
	Key string
	// ReqLimit is the maximum number of requests allowed per window.
	// A negative ReqLimit means the limiter is skipped for the request.
	ReqLimit int
	// WindowLen is the duration of one rate limit window.
	WindowLen time.Duration
	// IgnoreAfter, when true, causes all limiters after this one to be skipped.
	IgnoreAfter bool
}
// Limiter supplies the rate limit rule for a request and decides how the
// middleware responds when that rule is exceeded.
type Limiter interface {
	// Name returns the name of the limiter.
	Name() string
	// Rule returns the rate limit rule, including its key, to apply to r.
	Rule(r *http.Request) (rule *Rule, err error)
	// ShouldSetXRateLimitHeaders reports whether the X-RateLimit-* headers
	// should be set. The middleware passes a nil *Context when the request
	// was not rejected.
	ShouldSetXRateLimitHeaders(*Context) bool
	// OnRequestLimit returns the handler that is called when the rate limit
	// is exceeded.
	OnRequestLimit(*Context) http.HandlerFunc
}
// Counter stores request counts per key and window.
type Counter interface {
	// Get returns the current count recorded for key in the given window.
	Get(key string, window time.Time) (count int, err error) //nostyle:getters
	// Increment adds one to the count recorded for key in currWindow.
	Increment(key string, currWindow time.Time) error
}
// Compile-time check that *counter.Counter satisfies the Counter interface.
var _ Counter = (*counter.Counter)(nil)
// limiter pairs a Limiter with the functions used to read and update its
// request counts.
type limiter struct {
	// Limiter is the wrapped, user-supplied limiter.
	Limiter
	// Get returns the current count for the key and window.
	Get func(key string, window time.Time) (count int, err error) //nostyle:getters
	// Increment increments the count for the key and window.
	Increment func(key string, currWindow time.Time) error
}
// newLimiter wraps l in a limiter whose Get and Increment functions are bound.
//
// If l itself implements Counter, its own Get and Increment are used.
// Otherwise a counter created with counter.New is used; the duration passed to
// counter.New is twice the WindowLen of the rule l reports for an empty
// request, or one hour when no usable rule is available.
func newLimiter(l Limiter) *limiter {
	const defaultTTL = 1 * time.Hour

	ll := &limiter{Limiter: l}
	if c, ok := l.(Counter); ok {
		ll.Get = c.Get
		ll.Increment = c.Increment
		return ll
	}

	ttl := defaultTTL
	// Probe the limiter with an empty request to learn its window length.
	// The probe may fail or yield no rule at all; a nil rule must not be
	// dereferenced, so both cases fall back to the default.
	if r, err := ll.Rule(&http.Request{}); err == nil && r != nil && r.WindowLen > 0 {
		ttl = r.WindowLen * 2
	}
	cc := counter.New(ttl)
	ll.Get = cc.Get
	ll.Increment = cc.Increment
	return ll
}
// limitHandler holds the rule resolved by one limiter for one request,
// together with the rate limit state computed while evaluating it.
type limitHandler struct {
	// key, reqLimit and windowLen are copied from the Rule returned by the limiter.
	key       string
	reqLimit  int
	windowLen time.Duration
	// limiter is the limiter that produced the rule.
	limiter *limiter
	// rateLimitRemaining is the value reported in X-RateLimit-Remaining.
	rateLimitRemaining int
	// rateLimitReset is the value reported in X-RateLimit-Reset: the Unix time,
	// in seconds, at which the current window ends.
	rateLimitReset int
	// mu is held while the handler is evaluated.
	mu sync.Mutex
}
// status estimates the request rate at now for the window starting at
// currWindow.
//
// It reads the counts of the current window and of the window immediately
// before it, scales the previous count by the fraction of the current window
// that has not yet elapsed, and adds the current count. If either count cannot
// be read, the error is returned along with a zero rate.
func (lh *limitHandler) status(now, currWindow time.Time) (float64, error) {
	prevWindow := currWindow.Add(-lh.windowLen)

	curr, err := lh.limiter.Get(lh.key, currWindow)
	if err != nil {
		return 0, err
	}
	prev, err := lh.limiter.Get(lh.key, prevWindow)
	if err != nil {
		return 0, err
	}

	window := float64(lh.windowLen)
	elapsed := float64(now.Sub(currWindow))
	// Portion of the previous window's requests still counted against now.
	carried := float64(prev) * (window - elapsed) / window
	return carried + float64(curr), nil
}
// limitMw is the rate limit middleware; it applies its limiters to every
// request in slice order.
type limitMw struct {
	// limiters are the wrapped limiters consulted for each request.
	limiters []*limiter
}
// newLimitMw builds a limitMw from limiters, wrapping each one with newLimiter
// and preserving their order.
func newLimitMw(limiters []Limiter) *limitMw {
	var wrapped []*limiter
	for i := range limiters {
		wrapped = append(wrapped, newLimiter(limiters[i]))
	}
	mw := &limitMw{limiters: wrapped}
	return mw
}
// Handler returns an http.Handler that enforces the middleware's limiters
// before delegating to next.
//
// For each request the rule of every limiter is resolved in order:
//   - if a limiter fails to produce a rule, the request is rejected with
//     428 Precondition Required;
//   - a rule with a negative ReqLimit is skipped;
//   - a rule with IgnoreAfter set stops the resolution of later limiters.
//
// The resolved rules are then evaluated concurrently. If one of them is
// exceeded, the limiter's OnRequestLimit handler responds (with Retry-After and
// X-RateLimit-* headers when the limiter asks for them); any other evaluation
// error is reported with the status code carried by the error. Otherwise the
// X-RateLimit-* headers of the last evaluated limiter are set, if it asks for
// them, and next is called.
func (lm *limitMw) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		now := time.Now().UTC()

		// setRateLimitHeaders writes the X-RateLimit-* headers for lh.
		setRateLimitHeaders := func(lh *limitHandler) {
			w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", lh.reqLimit))
			w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", lh.rateLimitRemaining))
			w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", lh.rateLimitReset))
		}

		// Resolve every rule before starting any goroutine, so that a failing
		// Rule call never leaves already-started goroutines running (and
		// incrementing counters) after the handler has returned.
		lhs := make([]*limitHandler, 0, len(lm.limiters))
		for _, l := range lm.limiters {
			rule, err := l.Rule(r)
			if err != nil {
				http.Error(w, err.Error(), http.StatusPreconditionRequired)
				return
			}
			if rule == nil {
				// A limiter must return a rule when it returns no error.
				http.Error(w, "rl: limiter returned a nil rule", http.StatusInternalServerError)
				return
			}
			// A negative request limit means this limiter is skipped.
			if rule.ReqLimit >= 0 {
				lhs = append(lhs, &limitHandler{
					key:       rule.Key,
					reqLimit:  rule.ReqLimit,
					windowLen: rule.WindowLen,
					limiter:   l,
				})
			}
			if rule.IgnoreAfter {
				// Skip all limiters after this limiter.
				break
			}
		}

		eg, ctx := errgroup.WithContext(context.Background())
		for _, lh := range lhs {
			lh := lh // per-iteration copy for toolchains older than Go 1.22
			eg.Go(func() error {
				lh.mu.Lock()
				defer lh.mu.Unlock()
				currWindow := now.Truncate(lh.windowLen)
				lh.rateLimitRemaining = 0
				lh.rateLimitReset = int(currWindow.Add(lh.windowLen).Unix())
				select {
				// Check if another limiter already failed before calling lh.status().
				case <-ctx.Done():
					// Increment must be called even if the request limit is already exceeded.
					if err := lh.limiter.Increment(lh.key, currWindow); err != nil {
						return newContext(http.StatusInternalServerError, err, lh, next)
					}
					return nil
				default:
				}
				rate, err := lh.status(now, currWindow)
				if err != nil {
					return newContext(http.StatusPreconditionRequired, err, lh, next)
				}
				nrate := int(math.Round(rate))
				if nrate >= lh.reqLimit {
					return newContext(http.StatusTooManyRequests, ErrRateLimitExceeded, lh, next)
				}
				lh.rateLimitRemaining = lh.reqLimit - nrate
				if err := lh.limiter.Increment(lh.key, currWindow); err != nil {
					return newContext(http.StatusInternalServerError, err, lh, next)
				}
				return nil
			})
		}

		// Wait for all limiters to finish; only the first error is handled.
		if err := eg.Wait(); err != nil {
			var e *Context
			if !errors.As(err, &e) {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			// Ask the limiter once so that all headers of this response agree.
			setHeaders := e.lh.limiter.ShouldSetXRateLimitHeaders(e)
			if setHeaders {
				setRateLimitHeaders(e.lh)
			}
			if errors.Is(e.Err, ErrRateLimitExceeded) {
				// Rate limit exceeded
				if setHeaders {
					w.Header().Set("Retry-After", fmt.Sprintf("%d", int(e.lh.windowLen.Seconds()))) // RFC 6585
				}
				e.lh.limiter.OnRequestLimit(e)(w, r)
				return
			}
			http.Error(w, e.Error(), e.StatusCode)
			return
		}

		// Set X-RateLimit-* headers using the last evaluated limiter. Only rules
		// with a non-negative request limit are ever added to lhs.
		if n := len(lhs); n > 0 {
			if last := lhs[n-1]; last.limiter.ShouldSetXRateLimitHeaders(nil) {
				setRateLimitHeaders(last)
			}
		}
		next.ServeHTTP(w, r)
	})
}
// New returns a new rate limiter middleware built from limiters.
//
// Limiters should be ordered from the least strict to the most strict rate
// limit: when a request is allowed, the X-RateLimit-* headers returned to the
// client are taken from the last limiter that was applied.
func New(limiters ...Limiter) func(next http.Handler) http.Handler {
	return newLimitMw(limiters).Handler
}