Skip to content

Commit d036b87

Browse files
authored
Refactor rate limiting to use separate read/write Redis operations (#6702)
1 parent c86cac4 commit d036b87

File tree

3 files changed

+198
-22
lines changed

3 files changed

+198
-22
lines changed

.changeset/brown-moles-peel.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"@thirdweb-dev/service-utils": minor
3+
---
4+
5+
update rateLimit function

packages/service-utils/src/core/rateLimit/index.ts

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ const RATE_LIMIT_WINDOW_SECONDS = 10;
55

66
// Redis interface compatible with ioredis (Node) and upstash (Cloudflare Workers).
77
type IRedis = {
8-
incr: (key: string) => Promise<number>;
9-
expire: (key: string, ttlSeconds: number) => Promise<0 | 1>;
8+
get: (key: string) => Promise<string | null>;
9+
expire(key: string, seconds: number): Promise<number>;
10+
incrBy(key: string, value: number): Promise<number>;
1011
};
1112

1213
export async function rateLimit(args: {
@@ -20,8 +21,20 @@ export async function rateLimit(args: {
2021
* @default 1.0
2122
*/
2223
sampleRate?: number;
24+
/**
25+
* The number of requests to increment by.
26+
* @default 1
27+
*/
28+
increment?: number;
2329
}): Promise<RateLimitResult> {
24-
const { team, limitPerSecond, serviceConfig, redis, sampleRate = 1.0 } = args;
30+
const {
31+
team,
32+
limitPerSecond,
33+
serviceConfig,
34+
redis,
35+
sampleRate = 1.0,
36+
increment = 1,
37+
} = args;
2538

2639
const shouldSampleRequest = Math.random() < sampleRate;
2740
if (!shouldSampleRequest) {
@@ -49,12 +62,8 @@ export async function rateLimit(args: {
4962
RATE_LIMIT_WINDOW_SECONDS;
5063
const key = `rate-limit:${serviceScope}:${team.id}:${timestampWindow}`;
5164

52-
// Increment and get the current request count in this window.
53-
const requestCount = await redis.incr(key);
54-
if (requestCount === 1) {
55-
// For the first increment, set an expiration to clean up this key.
56-
await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS);
57-
}
65+
// first read the request count from redis
66+
const requestCount = Number((await redis.get(key).catch(() => "0")) || "0");
5867

5968
// Get the limit for this window accounting for the sample rate.
6069
const limitPerWindow =
@@ -71,9 +80,21 @@ export async function rateLimit(args: {
7180
};
7281
}
7382

83+
// do not await this, it just needs to execute at all
84+
(async () =>
85+
// always incrementBy the amount specified for the key
86+
await redis.incrBy(key, increment).then(async () => {
87+
// if the initial request count was 0, set the key to expire in the future
88+
if (requestCount === 0) {
89+
await redis.expire(key, RATE_LIMIT_WINDOW_SECONDS);
90+
}
91+
}))().catch(() => {
92+
console.error("Error incrementing rate limit key", key);
93+
});
94+
7495
return {
7596
rateLimited: false,
76-
requestCount,
97+
requestCount: requestCount + increment,
7798
rateLimit: limitPerWindow,
7899
};
79100
}

packages/service-utils/src/core/rateLimit/rateLimit.test.ts

Lines changed: 162 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,18 @@ import { validServiceConfig, validTeamResponse } from "../../mocks.js";
33
import { rateLimit } from "./index.js";
44

55
const mockRedis = {
6-
incr: vi.fn(),
6+
get: vi.fn(),
77
expire: vi.fn(),
8+
incrBy: vi.fn(),
89
};
910

1011
describe("rateLimit", () => {
1112
beforeEach(() => {
1213
// Clear mock function calls and reset any necessary state.
1314
vi.clearAllMocks();
14-
mockRedis.incr.mockReset();
15+
mockRedis.get.mockReset();
1516
mockRedis.expire.mockReset();
17+
mockRedis.incrBy.mockReset();
1618
});
1719

1820
afterEach(() => {
@@ -35,7 +37,7 @@ describe("rateLimit", () => {
3537
});
3638

3739
it("should not rate limit if within limit", async () => {
38-
mockRedis.incr.mockResolvedValue(50); // Current count is 50 requests in 10 seconds.
40+
mockRedis.get.mockResolvedValue("50"); // Current count is 50 requests in 10 seconds.
3941

4042
const result = await rateLimit({
4143
team: validTeamResponse,
@@ -46,15 +48,15 @@ describe("rateLimit", () => {
4648

4749
expect(result).toEqual({
4850
rateLimited: false,
49-
requestCount: 50,
51+
requestCount: 51,
5052
rateLimit: 50,
5153
});
5254

53-
expect(mockRedis.expire).not.toHaveBeenCalled();
55+
expect(mockRedis.incrBy).toHaveBeenCalledTimes(1);
5456
});
5557

5658
it("should rate limit if exceeded hard limit", async () => {
57-
mockRedis.incr.mockResolvedValue(51);
59+
mockRedis.get.mockResolvedValue(51);
5860

5961
const result = await rateLimit({
6062
team: validTeamResponse,
@@ -72,11 +74,11 @@ describe("rateLimit", () => {
7274
errorCode: "RATE_LIMIT_EXCEEDED",
7375
});
7476

75-
expect(mockRedis.expire).not.toHaveBeenCalled();
77+
expect(mockRedis.incrBy).not.toHaveBeenCalled();
7678
});
7779

7880
it("expires on the first incr request only", async () => {
79-
mockRedis.incr.mockResolvedValue(1);
81+
mockRedis.get.mockResolvedValue("1");
8082

8183
const result = await rateLimit({
8284
team: validTeamResponse,
@@ -87,14 +89,14 @@ describe("rateLimit", () => {
8789

8890
expect(result).toEqual({
8991
rateLimited: false,
90-
requestCount: 1,
92+
requestCount: 2,
9193
rateLimit: 50,
9294
});
93-
expect(mockRedis.expire).toHaveBeenCalled();
95+
expect(mockRedis.incrBy).toHaveBeenCalled();
9496
});
9597

9698
it("enforces rate limit if sampled (hit)", async () => {
97-
mockRedis.incr.mockResolvedValue(10);
99+
mockRedis.get.mockResolvedValue("10");
98100
vi.spyOn(global.Math, "random").mockReturnValue(0.08);
99101

100102
const result = await rateLimit({
@@ -117,7 +119,7 @@ describe("rateLimit", () => {
117119
});
118120

119121
it("does not enforce rate limit if sampled (miss)", async () => {
120-
mockRedis.incr.mockResolvedValue(10);
122+
mockRedis.get.mockResolvedValue(10);
121123
vi.spyOn(global.Math, "random").mockReturnValue(0.15);
122124

123125
const result = await rateLimit({
@@ -134,4 +136,152 @@ describe("rateLimit", () => {
134136
rateLimit: 0,
135137
});
136138
});
139+
140+
it("should handle redis get failure gracefully", async () => {
141+
mockRedis.get.mockRejectedValue(new Error("Redis connection error"));
142+
143+
const result = await rateLimit({
144+
team: validTeamResponse,
145+
limitPerSecond: 5,
146+
serviceConfig: validServiceConfig,
147+
redis: mockRedis,
148+
});
149+
150+
expect(result).toEqual({
151+
rateLimited: false,
152+
requestCount: 1,
153+
rateLimit: 50,
154+
});
155+
});
156+
157+
it("should handle zero requests correctly", async () => {
158+
mockRedis.get.mockResolvedValue("0");
159+
160+
const result = await rateLimit({
161+
team: validTeamResponse,
162+
limitPerSecond: 5,
163+
serviceConfig: validServiceConfig,
164+
redis: mockRedis,
165+
});
166+
167+
expect(result).toEqual({
168+
rateLimited: false,
169+
requestCount: 1,
170+
rateLimit: 50,
171+
});
172+
expect(mockRedis.incrBy).toHaveBeenCalledWith(expect.any(String), 1);
173+
});
174+
175+
it("should handle null response from redis", async () => {
176+
mockRedis.get.mockResolvedValue(null);
177+
178+
const result = await rateLimit({
179+
team: validTeamResponse,
180+
limitPerSecond: 5,
181+
serviceConfig: validServiceConfig,
182+
redis: mockRedis,
183+
});
184+
185+
expect(result).toEqual({
186+
rateLimited: false,
187+
requestCount: 1,
188+
rateLimit: 50,
189+
});
190+
});
191+
192+
it("should handle very low sample rates", async () => {
193+
mockRedis.get.mockResolvedValue("100");
194+
vi.spyOn(global.Math, "random").mockReturnValue(0.001);
195+
196+
const result = await rateLimit({
197+
team: validTeamResponse,
198+
limitPerSecond: 5,
199+
serviceConfig: validServiceConfig,
200+
redis: mockRedis,
201+
sampleRate: 0.01,
202+
});
203+
204+
expect(result).toEqual({
205+
rateLimited: true,
206+
requestCount: 100,
207+
rateLimit: 0.5,
208+
status: 429,
209+
errorMessage: expect.any(String),
210+
errorCode: "RATE_LIMIT_EXCEEDED",
211+
});
212+
});
213+
214+
it("should handle multiple concurrent requests with redis lag", async () => {
215+
// Mock initial state
216+
mockRedis.get.mockResolvedValue("0");
217+
218+
// Mock redis.set to have 100ms delay
219+
mockRedis.incrBy.mockImplementation(
220+
() =>
221+
new Promise((resolve) => {
222+
setTimeout(() => resolve(1), 100);
223+
}),
224+
);
225+
226+
// Make 3 concurrent requests
227+
const requests = Promise.all([
228+
rateLimit({
229+
team: validTeamResponse,
230+
limitPerSecond: 5,
231+
serviceConfig: validServiceConfig,
232+
redis: mockRedis,
233+
}),
234+
rateLimit({
235+
team: validTeamResponse,
236+
limitPerSecond: 5,
237+
serviceConfig: validServiceConfig,
238+
redis: mockRedis,
239+
}),
240+
rateLimit({
241+
team: validTeamResponse,
242+
limitPerSecond: 5,
243+
serviceConfig: validServiceConfig,
244+
redis: mockRedis,
245+
}),
246+
]);
247+
248+
const results = await requests;
249+
// All requests should succeed since they all see initial count of 0
250+
for (const result of results) {
251+
expect(result).toEqual({
252+
rateLimited: false,
253+
requestCount: 1,
254+
rateLimit: 50,
255+
});
256+
}
257+
258+
// Redis set should be called 3 times
259+
expect(mockRedis.incrBy).toHaveBeenCalledTimes(3);
260+
});
261+
262+
it("should handle custom increment values", async () => {
263+
// Mock initial state
264+
mockRedis.get.mockResolvedValue("5");
265+
mockRedis.incrBy.mockResolvedValue(10);
266+
267+
const result = await rateLimit({
268+
team: validTeamResponse,
269+
limitPerSecond: 20,
270+
serviceConfig: validServiceConfig,
271+
redis: mockRedis,
272+
increment: 5,
273+
});
274+
275+
expect(result).toEqual({
276+
rateLimited: false,
277+
requestCount: 10,
278+
rateLimit: 200,
279+
});
280+
281+
// Verify redis was called with correct increment
282+
expect(mockRedis.incrBy).toHaveBeenCalledWith(
283+
expect.stringContaining("rate-limit"),
284+
5,
285+
);
286+
});
137287
});

0 commit comments

Comments
 (0)