Skip to content

Commit 80029e2

Browse files
authored
Implement rate-limiting from HTTP handler (e.g. by request payload) (#42)
1 parent 99b3b69 commit 80029e2

File tree

4 files changed

+176
-72
lines changed

4 files changed

+176
-72
lines changed

README.md

+36-8
Original file line numberDiff line numberDiff line change
@@ -78,36 +78,64 @@ r.Use(httprate.Limit(
7878
))
7979
```
8080

81-
### Send specific response for rate limited requests
81+
### Rate limit by request payload
82+
```go
83+
// Rate-limiter for login endpoint.
84+
loginRateLimiter := httprate.NewRateLimiter(5, time.Minute)
85+
86+
r.Post("/login", func(w http.ResponseWriter, r *http.Request) {
87+
var payload struct {
88+
Username string `json:"username"`
89+
Password string `json:"password"`
90+
}
91+
err := json.NewDecoder(r.Body).Decode(&payload)
92+
if err != nil || payload.Username == "" || payload.Password == "" {
93+
w.WriteHeader(400)
94+
return
95+
}
96+
97+
// Rate-limit login at 5 req/min.
98+
if loginRateLimiter.OnLimit(w, r, payload.Username) {
99+
return
100+
}
101+
102+
w.Write([]byte("login at 5 req/min\n"))
103+
})
104+
```
105+
106+
### Send specific response for rate-limited requests
107+
108+
The default response is `HTTP 429` with `Too Many Requests` body. You can override it with:
82109

83110
```go
84111
r.Use(httprate.Limit(
85112
10,
86113
time.Minute,
87114
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
88-
http.Error(w, `{"error": "Rate limited. Please slow down."}`, http.StatusTooManyRequests)
115+
http.Error(w, `{"error": "Rate-limited. Please, slow down."}`, http.StatusTooManyRequests)
89116
}),
90117
))
91118
```
92119

93-
### Send specific response for backend errors
120+
### Send specific response on errors
121+
122+
An error can be returned by:
123+
- A custom key function provided by `httprate.WithKeyFunc(customKeyFn)`
124+
- A custom backend provided by `httprateredis.WithRedisLimitCounter(customBackend)`
125+
- The default local in-memory counter is guaranteed not return any errors
126+
- Backends that fall-back to the local in-memory counter (e.g. [httprate-redis](https://github.com/go-chi/httprate-redis)) can choose not to return any errors either
94127

95128
```go
96129
r.Use(httprate.Limit(
97130
10,
98131
time.Minute,
99132
httprate.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) {
100-
// NOTE: The local in-memory counter is guaranteed not return any errors.
101-
// Other backends may return errors, depending on whether they have
102-
// in-memory fallback mechanism implemented in case of network errors.
103-
104133
http.Error(w, fmt.Sprintf(`{"error": %q}`, err), http.StatusPreconditionRequired)
105134
}),
106135
httprate.WithLimitCounter(customBackend),
107136
))
108137
```
109138

110-
111139
### Send custom response headers
112140

113141
```go

_example/main.go

+32-24
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"encoding/json"
56
"log"
67
"net/http"
78
"time"
@@ -15,52 +16,59 @@ func main() {
1516
r := chi.NewRouter()
1617
r.Use(middleware.Logger)
1718

19+
// Rate-limit all routes at 1000 req/min by IP address.
20+
r.Use(httprate.LimitByIP(1000, time.Minute))
21+
1822
r.Route("/admin", func(r chi.Router) {
1923
r.Use(func(next http.Handler) http.Handler {
2024
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21-
// Note: this is a mock middleware to set a userID on the request context
25+
// Note: This is a mock middleware to set a userID on the request context
2226
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "userID", "123")))
2327
})
2428
})
2529

26-
// Here we set a specific rate limit by ip address and userID
30+
// Rate-limit admin routes at 10 req/s by userID.
2731
r.Use(httprate.Limit(
28-
10,
29-
time.Minute,
30-
httprate.WithKeyFuncs(httprate.KeyByIP, func(r *http.Request) (string, error) {
31-
token := r.Context().Value("userID").(string)
32+
10, time.Second,
33+
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
34+
token, _ := r.Context().Value("userID").(string)
3235
return token, nil
3336
}),
34-
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
35-
// We can send custom responses for the rate limited requests, e.g. a JSON message
36-
w.Header().Set("Content-Type", "application/json")
37-
w.WriteHeader(http.StatusTooManyRequests)
38-
w.Write([]byte(`{"error": "Too many requests"}`))
39-
}),
4037
))
4138

4239
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
43-
w.Write([]byte("10 req/min\n"))
40+
w.Write([]byte("admin at 10 req/s\n"))
4441
})
4542
})
4643

47-
r.Group(func(r chi.Router) {
48-
// Here we set another rate limit (3 req/min) for a group of handlers.
49-
//
50-
// Note: in practice you don't need to have so many layered rate-limiters,
51-
// but the example here is to illustrate how to control the machinery.
52-
r.Use(httprate.LimitByIP(3, time.Minute))
44+
// Rate-limiter for login endpoint.
45+
loginRateLimiter := httprate.NewRateLimiter(5, time.Minute)
5346

54-
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
55-
w.Write([]byte("3 req/min\n"))
56-
})
47+
r.Post("/login", func(w http.ResponseWriter, r *http.Request) {
48+
var payload struct {
49+
Username string `json:"username"`
50+
Password string `json:"password"`
51+
}
52+
err := json.NewDecoder(r.Body).Decode(&payload)
53+
if err != nil || payload.Username == "" || payload.Password == "" {
54+
w.WriteHeader(400)
55+
return
56+
}
57+
58+
// Rate-limit login at 5 req/min.
59+
if loginRateLimiter.OnLimit(w, r, payload.Username) {
60+
return
61+
}
62+
63+
w.Write([]byte("login at 5 req/min\n"))
5764
})
5865

5966
log.Printf("Serving at localhost:3333")
6067
log.Println()
6168
log.Printf("Try running:")
62-
log.Printf("curl -v http://localhost:3333")
63-
log.Printf("curl -v http://localhost:3333/admin")
69+
log.Printf(`curl -v http://localhost:3333?[0-1000]`)
70+
log.Printf(`curl -v http://localhost:3333/admin?[1-12]`)
71+
log.Printf(`curl -v http://localhost:3333/login\?[1-8] --data '{"username":"alice","password":"***"}'`)
6472

6573
http.ListenAndServe(":3333", r)
6674
}

limiter.go

+51-40
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,56 @@ type rateLimiter struct {
6666
mu sync.Mutex
6767
}
6868

69+
// OnLimit checks the rate limit for the given key. If the limit is reached, it returns true
70+
// and automatically sends HTTP response. The caller should halt further request processing.
71+
// If the limit is not reached, it increments the request count and returns false, allowing
72+
// the request to proceed.
73+
func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool {
74+
currentWindow := time.Now().UTC().Truncate(l.windowLength)
75+
ctx := r.Context()
76+
77+
limit := l.requestLimit
78+
if val := getRequestLimit(ctx); val > 0 {
79+
limit = val
80+
}
81+
setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit))
82+
setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix()))
83+
84+
l.mu.Lock()
85+
_, rateFloat, err := l.calculateRate(key, limit)
86+
if err != nil {
87+
l.mu.Unlock()
88+
l.onError(w, r, err)
89+
return true
90+
}
91+
rate := int(math.Round(rateFloat))
92+
93+
increment := getIncrement(r.Context())
94+
if increment > 1 {
95+
setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment))
96+
}
97+
98+
if rate+increment > limit {
99+
setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate))
100+
101+
l.mu.Unlock()
102+
setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585
103+
l.onRateLimited(w, r)
104+
return true
105+
}
106+
107+
err = l.limitCounter.IncrementBy(key, currentWindow, increment)
108+
if err != nil {
109+
l.mu.Unlock()
110+
l.onError(w, r, err)
111+
return true
112+
}
113+
l.mu.Unlock()
114+
115+
setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate-increment))
116+
return false
117+
}
118+
69119
func (l *rateLimiter) Counter() LimitCounter {
70120
return l.limitCounter
71121
}
@@ -82,49 +132,10 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
82132
return
83133
}
84134

85-
currentWindow := time.Now().UTC().Truncate(l.windowLength)
86-
ctx := r.Context()
87-
88-
limit := l.requestLimit
89-
if val := getRequestLimit(ctx); val > 0 {
90-
limit = val
91-
}
92-
setHeader(w, l.headers.Limit, fmt.Sprintf("%d", limit))
93-
setHeader(w, l.headers.Reset, fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix()))
94-
95-
l.mu.Lock()
96-
_, rateFloat, err := l.calculateRate(key, limit)
97-
if err != nil {
98-
l.mu.Unlock()
99-
l.onError(w, r, err)
100-
return
101-
}
102-
rate := int(math.Round(rateFloat))
103-
104-
increment := getIncrement(r.Context())
105-
if increment > 1 {
106-
setHeader(w, l.headers.Increment, fmt.Sprintf("%d", increment))
107-
}
108-
109-
if rate+increment > limit {
110-
setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate))
111-
112-
l.mu.Unlock()
113-
setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585
114-
l.onRateLimited(w, r)
135+
if l.OnLimit(w, r, key) {
115136
return
116137
}
117138

118-
err = l.limitCounter.IncrementBy(key, currentWindow, increment)
119-
if err != nil {
120-
l.mu.Unlock()
121-
l.onError(w, r, err)
122-
return
123-
}
124-
l.mu.Unlock()
125-
126-
setHeader(w, l.headers.Remaining, fmt.Sprintf("%d", limit-rate-increment))
127-
128139
next.ServeHTTP(w, r)
129140
})
130141
}

limiter_test.go

+57
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package httprate_test
33
import (
44
"bytes"
55
"context"
6+
"encoding/json"
67
"io"
78
"net/http"
89
"net/http/httptest"
@@ -437,3 +438,59 @@ func TestOverrideRequestLimit(t *testing.T) {
437438
}
438439
}
439440
}
441+
442+
func TestRateLimitPayload(t *testing.T) {
443+
loginRateLimiter := httprate.NewRateLimiter(5, time.Minute)
444+
445+
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
446+
var payload struct {
447+
Username string `json:"username"`
448+
Password string `json:"password"`
449+
}
450+
err := json.NewDecoder(r.Body).Decode(&payload)
451+
if err != nil || payload.Username == "" || payload.Password == "" {
452+
w.WriteHeader(400)
453+
return
454+
}
455+
456+
// Rate-limit login at 5 req/min.
457+
if loginRateLimiter.OnLimit(w, r, payload.Username) {
458+
return
459+
}
460+
461+
w.Write([]byte("login at 5 req/min\n"))
462+
})
463+
464+
responses := []struct {
465+
StatusCode int
466+
Body string
467+
}{
468+
{StatusCode: 200, Body: "login at 5 req/min"},
469+
{StatusCode: 200, Body: "login at 5 req/min"},
470+
{StatusCode: 200, Body: "login at 5 req/min"},
471+
{StatusCode: 200, Body: "login at 5 req/min"},
472+
{StatusCode: 200, Body: "login at 5 req/min"},
473+
{StatusCode: 429, Body: "Too Many Requests"},
474+
{StatusCode: 429, Body: "Too Many Requests"},
475+
{StatusCode: 429, Body: "Too Many Requests"},
476+
}
477+
for i, response := range responses {
478+
req, err := http.NewRequest("GET", "/", strings.NewReader(`{"username":"alice","password":"***"}`))
479+
if err != nil {
480+
t.Errorf("failed = %v", err)
481+
}
482+
483+
recorder := httptest.NewRecorder()
484+
h.ServeHTTP(recorder, req)
485+
result := recorder.Result()
486+
if respStatus := result.StatusCode; respStatus != response.StatusCode {
487+
t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respStatus, response.StatusCode)
488+
}
489+
body, _ := io.ReadAll(result.Body)
490+
respBody := strings.TrimSuffix(string(body), "\n")
491+
492+
if string(respBody) != response.Body {
493+
t.Errorf("resp.Body(%v) = %q, want %q", i, respBody, response.Body)
494+
}
495+
}
496+
}

0 commit comments

Comments
 (0)