Skip to content

Commit 4cfe7bc

Browse files
authored
feat(rag): add api /retrieval/bot (#794)
1 parent 9968251 commit 4cfe7bc

File tree

3 files changed

+76
-25
lines changed

3 files changed

+76
-25
lines changed

server/bot/builder.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
from core.models.user import User
55
from whiskerrag_client import APIClient
66
from whiskerrag_types.model import (
7-
KnowledgeCreate,
8-
KnowledgeSplitConfig,
7+
GithubRepoCreate,
8+
BaseCharSplitConfig,
9+
EmbeddingModelEnum,
910
KnowledgeSourceEnum,
1011
KnowledgeTypeEnum,
1112
GithubRepoSourceConfig,
@@ -81,17 +82,19 @@ async def bot_builder(
8182
)
8283
await api_client.knowledge.add_knowledge(
8384
[
84-
KnowledgeCreate(
85-
source_type=KnowledgeSourceEnum.GITHUB_REPO,
86-
knowledge_type=KnowledgeTypeEnum.FOLDER,
85+
GithubRepoCreate(
8786
space_id=repo_name,
87+
knowledge_type=KnowledgeTypeEnum.FOLDER,
8888
knowledge_name=repo_name,
89+
metadata={},
90+
embedding_model_name=EmbeddingModelEnum.OPENAI,
91+
source_type=KnowledgeSourceEnum.GITHUB_REPO,
8992
source_config=GithubRepoSourceConfig(
9093
repo_name=repo_name, auth_token=user.access_token
9194
),
92-
split_config=KnowledgeSplitConfig(
93-
chunk_size=500,
94-
chunk_overlap=100,
95+
split_config=BaseCharSplitConfig(
96+
chunk_size=1500,
97+
chunk_overlap=200,
9598
),
9699
)
97100
]

server/rag/router.py

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import json
2-
from typing import Annotated, List
2+
from typing import Annotated, List, Optional
33

44
from fastapi import APIRouter, Depends, HTTPException, status
55
from openai import BaseModel
6+
from pydantic import Field
67
from auth.get_user_info import get_user
78
from core.models.user import User
89
from utils.env import get_env_variable
@@ -12,14 +13,28 @@
1213
Knowledge,
1314
Task,
1415
Chunk,
15-
KnowledgeCreate,
16+
GithubRepoCreate,
1617
KnowledgeTypeEnum,
1718
KnowledgeSourceEnum,
1819
GithubRepoSourceConfig,
19-
KnowledgeSplitConfig,
20+
EmbeddingModelEnum,
21+
BaseCharSplitConfig,
22+
RetrievalChunk,
23+
RetrievalBySpaceRequest
2024
)
2125
from auth.rate_limit import verify_rate_limit
2226

27+
28+
class RetrievalByBotRequest(BaseModel):
29+
question: str = Field(..., description="question")
30+
bot_id_list: List[str] = Field([], description="petercat bot id")
31+
repo_name_list:List[str] = Field([], description="github repo name list")
32+
top: Optional[int] = Field(10, description="top k", ge=1, le=500)
33+
similarity_threshold: Optional[float] = Field(0.6, description="similarity threshold",ge=0.0,le=1.0)
34+
metadata_filter:Optional[dict]=Field({}, description="metadata filter")
35+
class Config:
36+
extra = "forbid"
37+
2338
router = APIRouter(
2439
prefix="/api/rag",
2540
tags=["rag"],
@@ -51,19 +66,21 @@ async def reload_repo(
5166
)
5267
res = await api_client.knowledge.add_knowledge(
5368
[
54-
KnowledgeCreate(
55-
source_type=KnowledgeSourceEnum.GITHUB_REPO,
56-
knowledge_type=KnowledgeTypeEnum.FOLDER,
57-
space_id=request.repo_name,
58-
knowledge_name=request.repo_name,
59-
source_config=GithubRepoSourceConfig(
60-
repo_name=request.repo_name, auth_token=user.access_token
61-
),
62-
split_config=KnowledgeSplitConfig(
63-
chunk_size=1000,
64-
chunk_overlap=200,
65-
),
66-
)
69+
GithubRepoCreate(
70+
source_type=KnowledgeSourceEnum.GITHUB_REPO,
71+
knowledge_type=KnowledgeTypeEnum.FOLDER,
72+
73+
space_id=request.repo_name,
74+
knowledge_name=request.repo_name,
75+
embedding_model_name=EmbeddingModelEnum.OPENAI,
76+
source_config=GithubRepoSourceConfig(
77+
repo_name=request.repo_name, auth_token=user.access_token
78+
),
79+
split_config=BaseCharSplitConfig(
80+
chunk_size=1500,
81+
chunk_overlap=200,
82+
),
83+
)
6784
]
6885
)
6986
return res
@@ -145,3 +162,34 @@ async def restart_rag_task(
145162
return res
146163
except Exception as e:
147164
return json.dumps({"success": False, "message": str(e)})
165+
166+
167+
@router.post("/retrieval/bot", dependencies=[Depends(verify_rate_limit)])
168+
async def retrievalBot(
169+
params: RetrievalByBotRequest,
170+
user: Annotated[User | None, Depends(get_user)] = None,
171+
)->List[RetrievalChunk]:
172+
bot_id_list = params.bot_id_list
173+
repo_name_list = params.repo_name_list
174+
space_id_list = repo_name_list + bot_id_list
175+
if len(space_id_list) == 0:
176+
raise HTTPException(
177+
status_code=400,
178+
detail="At least one of bot_id_list and repo_name_list must not be empty",
179+
)
180+
api_client = APIClient(
181+
base_url=get_env_variable("WHISKER_API_URL"),
182+
token=get_env_variable("WHISKER_API_KEY"),
183+
timeout=30,
184+
)
185+
retrieval_res = await api_client.retrieval.retrieve_space_content(
186+
RetrievalBySpaceRequest(
187+
space_id_list=space_id_list,
188+
question=params.question,
189+
embedding_model_name=EmbeddingModelEnum.OPENAI,
190+
similarity_threshold=params.similarity_threshold,
191+
top=params.top,
192+
metadata_filter=params.metadata_filter,
193+
)
194+
)
195+
return retrieval_res

server/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@ requests
2626
httpx==0.27.2
2727
urllib3>=2.2.2
2828
toolz
29-
whiskerrag>=0.0.15
29+
whiskerrag>=0.0.27

0 commit comments

Comments
 (0)