Skip to content

Commit c4c778c

Browse files
authored
Export RateLimiter type (#43)
So users pass *http.RateLimiter (or save in their server struct) and use the new .OnLimit() feature from #42.
1 parent 80029e2 commit c4c778c

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

httprate.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func Limit(requestLimit int, windowLength time.Duration, options ...Option) func
1212
}
1313

1414
type KeyFunc func(r *http.Request) (string, error)
15-
type Option func(rl *rateLimiter)
15+
type Option func(rl *RateLimiter)
1616

1717
// Set custom response headers. If empty, the header is omitted.
1818
type ResponseHeaders struct {
@@ -72,7 +72,7 @@ func KeyByEndpoint(r *http.Request) (string, error) {
7272
}
7373

7474
func WithKeyFuncs(keyFuncs ...KeyFunc) Option {
75-
return func(rl *rateLimiter) {
75+
return func(rl *RateLimiter) {
7676
if len(keyFuncs) > 0 {
7777
rl.keyFn = composedKeyFunc(keyFuncs...)
7878
}
@@ -88,31 +88,31 @@ func WithKeyByRealIP() Option {
8888
}
8989

9090
func WithLimitHandler(h http.HandlerFunc) Option {
91-
return func(rl *rateLimiter) {
91+
return func(rl *RateLimiter) {
9292
rl.onRateLimited = h
9393
}
9494
}
9595

9696
func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option {
97-
return func(rl *rateLimiter) {
97+
return func(rl *RateLimiter) {
9898
rl.onError = h
9999
}
100100
}
101101

102102
func WithLimitCounter(c LimitCounter) Option {
103-
return func(rl *rateLimiter) {
103+
return func(rl *RateLimiter) {
104104
rl.limitCounter = c
105105
}
106106
}
107107

108108
func WithResponseHeaders(headers ResponseHeaders) Option {
109-
return func(rl *rateLimiter) {
109+
return func(rl *RateLimiter) {
110110
rl.headers = headers
111111
}
112112
}
113113

114114
func WithNoop() Option {
115-
return func(rl *rateLimiter) {}
115+
return func(rl *RateLimiter) {}
116116
}
117117

118118
func composedKeyFunc(keyFuncs ...KeyFunc) KeyFunc {

limiter.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ type LimitCounter interface {
1515
Get(key string, currentWindow, previousWindow time.Time) (int, int, error)
1616
}
1717

18-
func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *rateLimiter {
19-
rl := &rateLimiter{
18+
func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Option) *RateLimiter {
19+
rl := &RateLimiter{
2020
requestLimit: requestLimit,
2121
windowLength: windowLength,
2222
headers: ResponseHeaders{
@@ -55,7 +55,7 @@ func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt
5555
return rl
5656
}
5757

58-
type rateLimiter struct {
58+
type RateLimiter struct {
5959
requestLimit int
6060
windowLength time.Duration
6161
keyFn KeyFunc
@@ -70,7 +70,7 @@ type rateLimiter struct {
7070
// and automatically sends HTTP response. The caller should halt further request processing.
7171
// If the limit is not reached, it increments the request count and returns false, allowing
7272
// the request to proceed.
73-
func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool {
73+
func (l *RateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string) bool {
7474
currentWindow := time.Now().UTC().Truncate(l.windowLength)
7575
ctx := r.Context()
7676

@@ -116,15 +116,15 @@ func (l *rateLimiter) OnLimit(w http.ResponseWriter, r *http.Request, key string
116116
return false
117117
}
118118

119-
func (l *rateLimiter) Counter() LimitCounter {
119+
func (l *RateLimiter) Counter() LimitCounter {
120120
return l.limitCounter
121121
}
122122

123-
func (l *rateLimiter) Status(key string) (bool, float64, error) {
123+
func (l *RateLimiter) Status(key string) (bool, float64, error) {
124124
return l.calculateRate(key, l.requestLimit)
125125
}
126126

127-
func (l *rateLimiter) Handler(next http.Handler) http.Handler {
127+
func (l *RateLimiter) Handler(next http.Handler) http.Handler {
128128
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
129129
key, err := l.keyFn(r)
130130
if err != nil {
@@ -140,7 +140,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
140140
})
141141
}
142142

143-
func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
143+
func (l *RateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
144144
now := time.Now().UTC()
145145
currentWindow := now.Truncate(l.windowLength)
146146
previousWindow := currentWindow.Add(-l.windowLength)

0 commit comments

Comments
 (0)