|
1 | 1 | import json
|
2 |
| -from typing import Annotated, List |
| 2 | +from typing import Annotated, List, Optional |
3 | 3 |
|
4 | 4 | from fastapi import APIRouter, Depends, HTTPException, status
|
5 | 5 | from openai import BaseModel
|
| 6 | +from pydantic import Field |
6 | 7 | from auth.get_user_info import get_user
|
7 | 8 | from core.models.user import User
|
8 | 9 | from utils.env import get_env_variable
|
|
12 | 13 | Knowledge,
|
13 | 14 | Task,
|
14 | 15 | Chunk,
|
15 |
| - KnowledgeCreate, |
| 16 | + GithubRepoCreate, |
16 | 17 | KnowledgeTypeEnum,
|
17 | 18 | KnowledgeSourceEnum,
|
18 | 19 | GithubRepoSourceConfig,
|
19 |
| - KnowledgeSplitConfig, |
| 20 | + EmbeddingModelEnum, |
| 21 | + BaseCharSplitConfig, |
| 22 | + RetrievalChunk, |
| 23 | + RetrievalBySpaceRequest |
20 | 24 | )
|
21 | 25 | from auth.rate_limit import verify_rate_limit
|
22 | 26 |
|
| 27 | + |
| 28 | +class RetrievalByBotRequest(BaseModel): |
| 29 | + question: str = Field(..., description="question") |
| 30 | + bot_id_list: List[str] = Field([], description="petercat bot id") |
| 31 | + repo_name_list:List[str] = Field([], description="github repo name list") |
| 32 | + top: Optional[int] = Field(10, description="top k", ge=1, le=500) |
| 33 | + similarity_threshold: Optional[float] = Field(0.6, description="similarity threshold",ge=0.0,le=1.0) |
| 34 | + metadata_filter:Optional[dict]=Field({}, description="metadata filter") |
| 35 | + class Config: |
| 36 | + extra = "forbid" |
| 37 | + |
23 | 38 | router = APIRouter(
|
24 | 39 | prefix="/api/rag",
|
25 | 40 | tags=["rag"],
|
@@ -51,19 +66,21 @@ async def reload_repo(
|
51 | 66 | )
|
52 | 67 | res = await api_client.knowledge.add_knowledge(
|
53 | 68 | [
|
54 |
| - KnowledgeCreate( |
55 |
| - source_type=KnowledgeSourceEnum.GITHUB_REPO, |
56 |
| - knowledge_type=KnowledgeTypeEnum.FOLDER, |
57 |
| - space_id=request.repo_name, |
58 |
| - knowledge_name=request.repo_name, |
59 |
| - source_config=GithubRepoSourceConfig( |
60 |
| - repo_name=request.repo_name, auth_token=user.access_token |
61 |
| - ), |
62 |
| - split_config=KnowledgeSplitConfig( |
63 |
| - chunk_size=1000, |
64 |
| - chunk_overlap=200, |
65 |
| - ), |
66 |
| - ) |
| 69 | + GithubRepoCreate( |
| 70 | + source_type=KnowledgeSourceEnum.GITHUB_REPO, |
| 71 | + knowledge_type=KnowledgeTypeEnum.FOLDER, |
| 72 | + |
| 73 | + space_id=request.repo_name, |
| 74 | + knowledge_name=request.repo_name, |
| 75 | + embedding_model_name=EmbeddingModelEnum.OPENAI, |
| 76 | + source_config=GithubRepoSourceConfig( |
| 77 | + repo_name=request.repo_name, auth_token=user.access_token |
| 78 | + ), |
| 79 | + split_config=BaseCharSplitConfig( |
| 80 | + chunk_size=1500, |
| 81 | + chunk_overlap=200, |
| 82 | + ), |
| 83 | + ) |
67 | 84 | ]
|
68 | 85 | )
|
69 | 86 | return res
|
@@ -145,3 +162,34 @@ async def restart_rag_task(
|
145 | 162 | return res
|
146 | 163 | except Exception as e:
|
147 | 164 | return json.dumps({"success": False, "message": str(e)})
|
| 165 | + |
| 166 | + |
| 167 | +@router.post("/retrieval/bot", dependencies=[Depends(verify_rate_limit)]) |
| 168 | +async def retrievalBot( |
| 169 | + params: RetrievalByBotRequest, |
| 170 | + user: Annotated[User | None, Depends(get_user)] = None, |
| 171 | +)->List[RetrievalChunk]: |
| 172 | + bot_id_list = params.bot_id_list |
| 173 | + repo_name_list = params.repo_name_list |
| 174 | + space_id_list = repo_name_list + bot_id_list |
| 175 | + if len(space_id_list) == 0: |
| 176 | + raise HTTPException( |
| 177 | + status_code=400, |
| 178 | + detail="At least one of bot_id_list and repo_name_list must not be empty", |
| 179 | + ) |
| 180 | + api_client = APIClient( |
| 181 | + base_url=get_env_variable("WHISKER_API_URL"), |
| 182 | + token=get_env_variable("WHISKER_API_KEY"), |
| 183 | + timeout=30, |
| 184 | + ) |
| 185 | + retrieval_res = await api_client.retrieval.retrieve_space_content( |
| 186 | + RetrievalBySpaceRequest( |
| 187 | + space_id_list=space_id_list, |
| 188 | + question=params.question, |
| 189 | + embedding_model_name=EmbeddingModelEnum.OPENAI, |
| 190 | + similarity_threshold=params.similarity_threshold, |
| 191 | + top=params.top, |
| 192 | + metadata_filter=params.metadata_filter, |
| 193 | + ) |
| 194 | + ) |
| 195 | + return retrieval_res |
0 commit comments