From 2598a1d00a54e8ec0f4c472ecbeb0369df48927b Mon Sep 17 00:00:00 2001 From: ch-liuzhide Date: Wed, 2 Apr 2025 17:34:12 +0800 Subject: [PATCH] feat(rag): add api /retrieval/bot --- server/bot/builder.py | 19 +++++----- server/rag/router.py | 80 ++++++++++++++++++++++++++++++++--------- server/requirements.txt | 2 +- 3 files changed, 76 insertions(+), 25 deletions(-) diff --git a/server/bot/builder.py b/server/bot/builder.py index ad6ade1a..4782ec44 100644 --- a/server/bot/builder.py +++ b/server/bot/builder.py @@ -4,8 +4,9 @@ from core.models.user import User from whiskerrag_client import APIClient from whiskerrag_types.model import ( - KnowledgeCreate, - KnowledgeSplitConfig, + GithubRepoCreate, + BaseCharSplitConfig, + EmbeddingModelEnum, KnowledgeSourceEnum, KnowledgeTypeEnum, GithubRepoSourceConfig, @@ -81,17 +82,19 @@ async def bot_builder( ) await api_client.knowledge.add_knowledge( [ - KnowledgeCreate( - source_type=KnowledgeSourceEnum.GITHUB_REPO, - knowledge_type=KnowledgeTypeEnum.FOLDER, + GithubRepoCreate( space_id=repo_name, + knowledge_type=KnowledgeTypeEnum.FOLDER, knowledge_name=repo_name, + metadata={}, + embedding_model_name=EmbeddingModelEnum.OPENAI, + source_type=KnowledgeSourceEnum.GITHUB_REPO, source_config=GithubRepoSourceConfig( repo_name=repo_name, auth_token=user.access_token ), - split_config=KnowledgeSplitConfig( - chunk_size=500, - chunk_overlap=100, + split_config=BaseCharSplitConfig( + chunk_size=1500, + chunk_overlap=200, ), ) ] diff --git a/server/rag/router.py b/server/rag/router.py index 7ba0006a..7f91a411 100644 --- a/server/rag/router.py +++ b/server/rag/router.py @@ -1,8 +1,9 @@ import json -from typing import Annotated, List +from typing import Annotated, List, Optional from fastapi import APIRouter, Depends, HTTPException, status from openai import BaseModel +from pydantic import Field from auth.get_user_info import get_user from core.models.user import User from utils.env import get_env_variable @@ -12,14 +13,28 @@ Knowledge, Task, Chunk, - KnowledgeCreate, + GithubRepoCreate, KnowledgeTypeEnum, KnowledgeSourceEnum, GithubRepoSourceConfig, - KnowledgeSplitConfig, + EmbeddingModelEnum, + BaseCharSplitConfig, + RetrievalChunk, + RetrievalBySpaceRequest ) from auth.rate_limit import verify_rate_limit + +class RetrievalByBotRequest(BaseModel): + question: str = Field(..., description="question") + bot_id_list: List[str] = Field([], description="petercat bot id") + repo_name_list:List[str] = Field([], description="github repo name list") + top: Optional[int] = Field(10, description="top k", ge=1, le=500) + similarity_threshold: Optional[float] = Field(0.6, description="similarity threshold",ge=0.0,le=1.0) + metadata_filter:Optional[dict]=Field({}, description="metadata filter") + class Config: + extra = "forbid" + router = APIRouter( prefix="/api/rag", tags=["rag"], @@ -51,19 +66,21 @@ async def reload_repo( ) res = await api_client.knowledge.add_knowledge( [ - KnowledgeCreate( - source_type=KnowledgeSourceEnum.GITHUB_REPO, - knowledge_type=KnowledgeTypeEnum.FOLDER, - space_id=request.repo_name, - knowledge_name=request.repo_name, - source_config=GithubRepoSourceConfig( - repo_name=request.repo_name, auth_token=user.access_token - ), - split_config=KnowledgeSplitConfig( - chunk_size=1000, - chunk_overlap=200, - ), - ) + GithubRepoCreate( + source_type=KnowledgeSourceEnum.GITHUB_REPO, + knowledge_type=KnowledgeTypeEnum.FOLDER, + + space_id=request.repo_name, + knowledge_name=request.repo_name, + embedding_model_name=EmbeddingModelEnum.OPENAI, + source_config=GithubRepoSourceConfig( + repo_name=request.repo_name, auth_token=user.access_token + ), + split_config=BaseCharSplitConfig( + chunk_size=1500, + chunk_overlap=200, + ), + ) ] ) return res @@ -145,3 +162,34 @@ async def restart_rag_task( return res except Exception as e: return json.dumps({"success": False, "message": str(e)}) + + +@router.post("/retrieval/bot", dependencies=[Depends(verify_rate_limit)]) +async def retrievalBot( + params: RetrievalByBotRequest, + user: Annotated[User | None, Depends(get_user)] = None, +)->List[RetrievalChunk]: + bot_id_list = params.bot_id_list + repo_name_list = params.repo_name_list + space_id_list = repo_name_list + bot_id_list + if len(space_id_list) == 0: + raise HTTPException( + status_code=400, + detail="At least one of bot_id_list and repo_name_list must not be empty", + ) + api_client = APIClient( + base_url=get_env_variable("WHISKER_API_URL"), + token=get_env_variable("WHISKER_API_KEY"), + timeout=30, + ) + retrieval_res = await api_client.retrieval.retrieve_space_content( + RetrievalBySpaceRequest( + space_id_list=space_id_list, + question=params.question, + embedding_model_name=EmbeddingModelEnum.OPENAI, + similarity_threshold=params.similarity_threshold, + top=params.top, + metadata_filter=params.metadata_filter, + ) + ) + return retrieval_res diff --git a/server/requirements.txt b/server/requirements.txt index fd69ebe3..3e577821 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -26,4 +26,4 @@ requests httpx==0.27.2 urllib3>=2.2.2 toolz -whiskerrag>=0.0.15 \ No newline at end of file +whiskerrag>=0.0.27 \ No newline at end of file