Skip to content

Commit d91b18f

Browse files
authored
Merge pull request #107 from ant-xuexiao/feat/implement-rate-limit-and-throw-429-if-exceeded
feat: impl rate limit and throw 429 if too many requests
2 parents c46ba0b + c34a706 commit d91b18f

File tree

6 files changed

+90
-32
lines changed

6 files changed

+90
-32
lines changed

server/auth/get_user_info.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from uilts.env import get_env_variable
2+
import httpx
3+
import secrets
4+
5+
6+
AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN")
7+
8+
async def getUserInfoByToken(token):
9+
userinfo_url = f"https://{AUTH0_DOMAIN}/userinfo"
10+
11+
headers = {"authorization": f"Bearer {token}"}
12+
async with httpx.AsyncClient() as client:
13+
user_info_response = await client.get(userinfo_url, headers=headers)
14+
if user_info_response.status_code == 200:
15+
user_info = user_info_response.json()
16+
data = {
17+
"id": user_info["sub"],
18+
"nickname": user_info.get("nickname"),
19+
"name": user_info.get("name"),
20+
"picture": user_info.get("picture"),
21+
"sub": user_info["sub"],
22+
"sid": secrets.token_urlsafe(32)
23+
}
24+
return data
25+
else :
26+
return {}

server/routers/auth.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from fastapi import APIRouter,Cookie, Request, HTTPException, status, Response
2-
from uilts.env import get_env_variable
3-
from fastapi_auth0 import Auth0
2+
43
from fastapi.responses import RedirectResponse
54
import httpx
5+
66
from db.supabase.client import get_client
7-
import secrets
7+
from auth.get_user_info import getUserInfoByToken
8+
from uilts.env import get_env_variable
89

910
AUTH0_DOMAIN = get_env_variable("AUTH0_DOMAIN")
1011

@@ -19,34 +20,12 @@
1920
WEB_URL = get_env_variable("WEB_URL")
2021

2122

22-
auth = Auth0(domain=AUTH0_DOMAIN, api_audience=API_AUDIENCE, scopes={'read': 'get list'})
23-
2423
router = APIRouter(
2524
prefix="/api/auth",
2625
tags=["auth"],
2726
responses={404: {"description": "Not found"}},
2827
)
2928

30-
async def getUserInfoByToken(token):
31-
userinfo_url = f"https://{AUTH0_DOMAIN}/userinfo"
32-
33-
34-
headers = {"authorization": f"Bearer {token}"}
35-
async with httpx.AsyncClient() as client:
36-
user_info_response = await client.get(userinfo_url, headers=headers)
37-
if user_info_response.status_code == 200:
38-
user_info = user_info_response.json()
39-
data = {
40-
"id": user_info["sub"],
41-
"nickname": user_info.get("nickname"),
42-
"name": user_info.get("name"),
43-
"picture": user_info.get("picture"),
44-
"sub": user_info["sub"],
45-
"sid": secrets.token_urlsafe(32)
46-
}
47-
return data
48-
else :
49-
return {}
5029

5130
async def getTokenByCode(code):
5231
token_url = f"https://{AUTH0_DOMAIN}/oauth/token"

server/routers/chat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from fastapi import APIRouter
1+
from fastapi import APIRouter, Depends
22
from fastapi.responses import StreamingResponse
33
from data_class import ChatData
44
from agent import qa_chat, bot_builder
5+
from verify.rate_limit import verify_rate_limit
56

67

78
router = APIRouter(
@@ -10,12 +11,12 @@
1011
responses={404: {"description": "Not found"}},
1112
)
1213

13-
@router.post("/qa", response_class=StreamingResponse)
14+
@router.post("/qa", response_class=StreamingResponse, dependencies=[Depends(verify_rate_limit)])
1415
def run_qa_chat(input_data: ChatData):
1516
result = qa_chat.agent_chat(input_data)
1617
return StreamingResponse(result, media_type="text/event-stream")
1718

18-
@router.post("/builder", response_class=StreamingResponse)
19+
@router.post("/builder", response_class=StreamingResponse, dependencies=[Depends(verify_rate_limit)])
1920
def run_bot_builder(input_data: ChatData):
2021
result = bot_builder.agent_chat(input_data)
2122
return StreamingResponse(result, media_type="text/event-stream")

server/routers/health_checker.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from fastapi import APIRouter, Depends, HTTPException
1+
from fastapi import APIRouter, Depends
2+
3+
from verify.rate_limit import verify_rate_limit
24

35
router = APIRouter(
46
prefix="/api",
@@ -8,4 +10,8 @@
810

911
@router.get("/health_checker")
1012
def health_checker():
13+
return { "Hello": "World" }
14+
15+
@router.get("/login_checker", dependencies=[Depends(verify_rate_limit)])
16+
def login_checker():
1117
return { "Hello": "World" }

server/routers/rag.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from fastapi import APIRouter
1+
from fastapi import APIRouter, Depends
22
from rag import retrieval
33
from data_class import S3Config
4+
from verify.rate_limit import verify_rate_limit
45

56
router = APIRouter(
67
prefix="/api",
@@ -9,12 +10,12 @@
910
)
1011

1112

12-
@router.post("/rag/add_knowledge")
13+
@router.post("/rag/add_knowledge", dependencies=[Depends(verify_rate_limit)])
1314
def add_knowledge(config: S3Config):
1415
data=retrieval.add_knowledge(config)
1516
return data
1617

17-
@router.post("/rag/search_knowledge")
18+
@router.post("/rag/search_knowledge", dependencies=[Depends(verify_rate_limit)])
1819
def search_knowledge(query: str):
1920
data=retrieval.search_knowledge(query)
2021
return data

server/verify/rate_limit.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
from fastapi import Cookie, HTTPException
3+
from datetime import datetime, timedelta
4+
5+
from auth.get_user_info import getUserInfoByToken
6+
from db.supabase.client import get_client
7+
8+
RATE_LIMIT_REQUESTS = 100
9+
RATE_LIMIT_DURATION = timedelta(minutes=1)
10+
11+
async def verify_rate_limit(petercat: str = Cookie(None)):
12+
if not petercat:
13+
raise HTTPException(status_code=403, detail="Must Login")
14+
user = await getUserInfoByToken(petercat)
15+
user_id = user['id']
16+
17+
supabase = get_client()
18+
table = supabase.table("user_token_usage")
19+
rows = table.select('id, user_id, last_request, request_count').eq('user_id', user_id).execute()
20+
21+
now = datetime.now().isoformat()
22+
user_usage = rows.data[0] if len(rows.data) > 0 else { "user_id": user_id, 'request_count': 0, 'last_request': now }
23+
24+
# Calculate the time elapsed since the last request
25+
elapsed_time = datetime.now() - datetime.fromisoformat(user_usage["last_request"])
26+
27+
if elapsed_time > RATE_LIMIT_DURATION:
28+
# If the elapsed time is greater than the rate limit duration, reset the count
29+
user_usage['request_count'] = 1
30+
else:
31+
if user_usage['request_count'] >= RATE_LIMIT_REQUESTS:
32+
# If the request count exceeds the rate limit, return a JSON response with an error message
33+
raise HTTPException(
34+
status_code=429,
35+
detail="Rate Limit Exceeded, Try It Later",
36+
headers={"Retry-After": "60"}
37+
)
38+
39+
user_usage['request_count'] = int(user_usage['request_count']) + 1
40+
41+
user_usage['last_request'] = datetime.now().isoformat()
42+
43+
table.upsert(user_usage).execute()
44+
45+
return user

0 commit comments

Comments
 (0)