@@ -3,16 +3,18 @@ import { validServiceConfig, validTeamResponse } from "../../mocks.js";
3
3
import { rateLimit } from "./index.js" ;
4
4
5
5
const mockRedis = {
6
- incr : vi . fn ( ) ,
6
+ get : vi . fn ( ) ,
7
7
expire : vi . fn ( ) ,
8
+ incrBy : vi . fn ( ) ,
8
9
} ;
9
10
10
11
describe ( "rateLimit" , ( ) => {
11
12
beforeEach ( ( ) => {
12
13
// Clear mock function calls and reset any necessary state.
13
14
vi . clearAllMocks ( ) ;
14
- mockRedis . incr . mockReset ( ) ;
15
+ mockRedis . get . mockReset ( ) ;
15
16
mockRedis . expire . mockReset ( ) ;
17
+ mockRedis . incrBy . mockReset ( ) ;
16
18
} ) ;
17
19
18
20
afterEach ( ( ) => {
@@ -35,7 +37,7 @@ describe("rateLimit", () => {
35
37
} ) ;
36
38
37
39
it ( "should not rate limit if within limit" , async ( ) => {
38
- mockRedis . incr . mockResolvedValue ( 50 ) ; // Current count is 50 requests in 10 seconds.
40
+ mockRedis . get . mockResolvedValue ( "50" ) ; // Current count is 50 requests in 10 seconds.
39
41
40
42
const result = await rateLimit ( {
41
43
team : validTeamResponse ,
@@ -46,15 +48,15 @@ describe("rateLimit", () => {
46
48
47
49
expect ( result ) . toEqual ( {
48
50
rateLimited : false ,
49
- requestCount : 50 ,
51
+ requestCount : 51 ,
50
52
rateLimit : 50 ,
51
53
} ) ;
52
54
53
- expect ( mockRedis . expire ) . not . toHaveBeenCalled ( ) ;
55
+ expect ( mockRedis . incrBy ) . toHaveBeenCalledTimes ( 1 ) ;
54
56
} ) ;
55
57
56
58
it ( "should rate limit if exceeded hard limit" , async ( ) => {
57
- mockRedis . incr . mockResolvedValue ( 51 ) ;
59
+ mockRedis . get . mockResolvedValue ( 51 ) ;
58
60
59
61
const result = await rateLimit ( {
60
62
team : validTeamResponse ,
@@ -72,11 +74,11 @@ describe("rateLimit", () => {
72
74
errorCode : "RATE_LIMIT_EXCEEDED" ,
73
75
} ) ;
74
76
75
- expect ( mockRedis . expire ) . not . toHaveBeenCalled ( ) ;
77
+ expect ( mockRedis . incrBy ) . not . toHaveBeenCalled ( ) ;
76
78
} ) ;
77
79
78
80
it ( "expires on the first incr request only" , async ( ) => {
79
- mockRedis . incr . mockResolvedValue ( 1 ) ;
81
+ mockRedis . get . mockResolvedValue ( "1" ) ;
80
82
81
83
const result = await rateLimit ( {
82
84
team : validTeamResponse ,
@@ -87,14 +89,14 @@ describe("rateLimit", () => {
87
89
88
90
expect ( result ) . toEqual ( {
89
91
rateLimited : false ,
90
- requestCount : 1 ,
92
+ requestCount : 2 ,
91
93
rateLimit : 50 ,
92
94
} ) ;
93
- expect ( mockRedis . expire ) . toHaveBeenCalled ( ) ;
95
+ expect ( mockRedis . incrBy ) . toHaveBeenCalled ( ) ;
94
96
} ) ;
95
97
96
98
it ( "enforces rate limit if sampled (hit)" , async ( ) => {
97
- mockRedis . incr . mockResolvedValue ( 10 ) ;
99
+ mockRedis . get . mockResolvedValue ( "10" ) ;
98
100
vi . spyOn ( global . Math , "random" ) . mockReturnValue ( 0.08 ) ;
99
101
100
102
const result = await rateLimit ( {
@@ -117,7 +119,7 @@ describe("rateLimit", () => {
117
119
} ) ;
118
120
119
121
it ( "does not enforce rate limit if sampled (miss)" , async ( ) => {
120
- mockRedis . incr . mockResolvedValue ( 10 ) ;
122
+ mockRedis . get . mockResolvedValue ( 10 ) ;
121
123
vi . spyOn ( global . Math , "random" ) . mockReturnValue ( 0.15 ) ;
122
124
123
125
const result = await rateLimit ( {
@@ -134,4 +136,152 @@ describe("rateLimit", () => {
134
136
rateLimit : 0 ,
135
137
} ) ;
136
138
} ) ;
139
+
140
+ it ( "should handle redis get failure gracefully" , async ( ) => {
141
+ mockRedis . get . mockRejectedValue ( new Error ( "Redis connection error" ) ) ;
142
+
143
+ const result = await rateLimit ( {
144
+ team : validTeamResponse ,
145
+ limitPerSecond : 5 ,
146
+ serviceConfig : validServiceConfig ,
147
+ redis : mockRedis ,
148
+ } ) ;
149
+
150
+ expect ( result ) . toEqual ( {
151
+ rateLimited : false ,
152
+ requestCount : 1 ,
153
+ rateLimit : 50 ,
154
+ } ) ;
155
+ } ) ;
156
+
157
+ it ( "should handle zero requests correctly" , async ( ) => {
158
+ mockRedis . get . mockResolvedValue ( "0" ) ;
159
+
160
+ const result = await rateLimit ( {
161
+ team : validTeamResponse ,
162
+ limitPerSecond : 5 ,
163
+ serviceConfig : validServiceConfig ,
164
+ redis : mockRedis ,
165
+ } ) ;
166
+
167
+ expect ( result ) . toEqual ( {
168
+ rateLimited : false ,
169
+ requestCount : 1 ,
170
+ rateLimit : 50 ,
171
+ } ) ;
172
+ expect ( mockRedis . incrBy ) . toHaveBeenCalledWith ( expect . any ( String ) , 1 ) ;
173
+ } ) ;
174
+
175
+ it ( "should handle null response from redis" , async ( ) => {
176
+ mockRedis . get . mockResolvedValue ( null ) ;
177
+
178
+ const result = await rateLimit ( {
179
+ team : validTeamResponse ,
180
+ limitPerSecond : 5 ,
181
+ serviceConfig : validServiceConfig ,
182
+ redis : mockRedis ,
183
+ } ) ;
184
+
185
+ expect ( result ) . toEqual ( {
186
+ rateLimited : false ,
187
+ requestCount : 1 ,
188
+ rateLimit : 50 ,
189
+ } ) ;
190
+ } ) ;
191
+
192
+ it ( "should handle very low sample rates" , async ( ) => {
193
+ mockRedis . get . mockResolvedValue ( "100" ) ;
194
+ vi . spyOn ( global . Math , "random" ) . mockReturnValue ( 0.001 ) ;
195
+
196
+ const result = await rateLimit ( {
197
+ team : validTeamResponse ,
198
+ limitPerSecond : 5 ,
199
+ serviceConfig : validServiceConfig ,
200
+ redis : mockRedis ,
201
+ sampleRate : 0.01 ,
202
+ } ) ;
203
+
204
+ expect ( result ) . toEqual ( {
205
+ rateLimited : true ,
206
+ requestCount : 100 ,
207
+ rateLimit : 0.5 ,
208
+ status : 429 ,
209
+ errorMessage : expect . any ( String ) ,
210
+ errorCode : "RATE_LIMIT_EXCEEDED" ,
211
+ } ) ;
212
+ } ) ;
213
+
214
+ it ( "should handle multiple concurrent requests with redis lag" , async ( ) => {
215
+ // Mock initial state
216
+ mockRedis . get . mockResolvedValue ( "0" ) ;
217
+
218
+ // Mock redis.set to have 100ms delay
219
+ mockRedis . incrBy . mockImplementation (
220
+ ( ) =>
221
+ new Promise ( ( resolve ) => {
222
+ setTimeout ( ( ) => resolve ( 1 ) , 100 ) ;
223
+ } ) ,
224
+ ) ;
225
+
226
+ // Make 3 concurrent requests
227
+ const requests = Promise . all ( [
228
+ rateLimit ( {
229
+ team : validTeamResponse ,
230
+ limitPerSecond : 5 ,
231
+ serviceConfig : validServiceConfig ,
232
+ redis : mockRedis ,
233
+ } ) ,
234
+ rateLimit ( {
235
+ team : validTeamResponse ,
236
+ limitPerSecond : 5 ,
237
+ serviceConfig : validServiceConfig ,
238
+ redis : mockRedis ,
239
+ } ) ,
240
+ rateLimit ( {
241
+ team : validTeamResponse ,
242
+ limitPerSecond : 5 ,
243
+ serviceConfig : validServiceConfig ,
244
+ redis : mockRedis ,
245
+ } ) ,
246
+ ] ) ;
247
+
248
+ const results = await requests ;
249
+ // All requests should succeed since they all see initial count of 0
250
+ for ( const result of results ) {
251
+ expect ( result ) . toEqual ( {
252
+ rateLimited : false ,
253
+ requestCount : 1 ,
254
+ rateLimit : 50 ,
255
+ } ) ;
256
+ }
257
+
258
+ // Redis set should be called 3 times
259
+ expect ( mockRedis . incrBy ) . toHaveBeenCalledTimes ( 3 ) ;
260
+ } ) ;
261
+
262
+ it ( "should handle custom increment values" , async ( ) => {
263
+ // Mock initial state
264
+ mockRedis . get . mockResolvedValue ( "5" ) ;
265
+ mockRedis . incrBy . mockResolvedValue ( 10 ) ;
266
+
267
+ const result = await rateLimit ( {
268
+ team : validTeamResponse ,
269
+ limitPerSecond : 20 ,
270
+ serviceConfig : validServiceConfig ,
271
+ redis : mockRedis ,
272
+ increment : 5 ,
273
+ } ) ;
274
+
275
+ expect ( result ) . toEqual ( {
276
+ rateLimited : false ,
277
+ requestCount : 10 ,
278
+ rateLimit : 200 ,
279
+ } ) ;
280
+
281
+ // Verify redis was called with correct increment
282
+ expect ( mockRedis . incrBy ) . toHaveBeenCalledWith (
283
+ expect . stringContaining ( "rate-limit" ) ,
284
+ 5 ,
285
+ ) ;
286
+ } ) ;
137
287
} ) ;
0 commit comments