Skip to content

Commit ba42c27

Browse files
authored
Merge pull request #110 from ant-xuexiao/rag_pr
feat: enable adding knowledge from git files or issues
2 parents 3258a19 + 8173741 commit ba42c27

File tree

6 files changed

+212
-23
lines changed

6 files changed

+212
-23
lines changed

server/.env.example

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ AUTH0_CLIENT_ID=auth0_client_id
2222
AUTH0_CLIENT_SECRET=auth0_client_secret
2323
API_URL=api_url
2424
WEB_URL=web_url
25+
26+
GITHUB_TOKEN=github_token # https://github.com/settings/tokens?type=beta

server/data_class.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Literal, Optional
22
from pydantic import BaseModel
33

44

@@ -23,3 +23,22 @@ class ExecuteMessage(BaseModel):
2323
class S3Config(BaseModel):
2424
s3_bucket: str
2525
file_path: Optional[str] = None
26+
27+
class GitIssueConfig(BaseModel):
    """Configuration for ingesting a GitHub repository's issues."""

    # Full repository name, e.g. "owner/repo".
    repo_name: str
    # Page number for paginated results; the GitHub API defaults to 1.
    page: Optional[int] = None
    # Number of items per page; the GitHub API defaults to 30.
    per_page: Optional[int] = 30
    # Filter on issue state. Can be one of: 'open', 'closed', 'all'.
    state: Optional[Literal["open", "closed", "all"]] = "all"


class GitDocConfig(BaseModel):
    """Configuration for ingesting a single documentation file from GitHub."""

    # Full repository name, e.g. "owner/repo".
    repo_name: str
    # File path of the documentation file, e.g. 'docs/blog/build-ghost.zh-CN.md'.
    file_path: str
    # Branch (git ref) to read from.
    branch: Optional[str] = "main"
44+

server/rag/github_file_loader.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
2+
"""
3+
This file was originally sourced from the https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/document_loaders/github.py
4+
and it has been modified based on the requirements provided by petercat.
5+
"""
6+
7+
import base64
from abc import ABC
from typing import Callable, Dict, Iterator, List, Optional

import requests
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.utils import get_from_dict_or_env
16+
17+
18+
class BaseGitHubLoader(BaseLoader, BaseModel, ABC):
    """Base class for loaders that read from the GitHub REST API."""

    # Full repository name, e.g. "owner/repo".
    repo: str
    # Personal access token - see https://github.com/settings/tokens?type=beta
    access_token: str
    # Base URL of the GitHub REST API.
    github_api_url: str = "https://api.github.com"

    @root_validator(pre=True, allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the access token from the input dict or the environment.

        NOTE(review): the fallback env var here is
        GITHUB_PERSONAL_ACCESS_TOKEN, while the project's .env.example
        defines GITHUB_TOKEN. Callers in this repo pass access_token
        explicitly, so the fallback is currently unused -- confirm the
        intended variable name before relying on it.
        """
        values["access_token"] = get_from_dict_or_env(
            values, "access_token", "GITHUB_PERSONAL_ACCESS_TOKEN"
        )
        return values

    @property
    def headers(self) -> Dict[str, str]:
        """HTTP headers for authenticated GitHub API requests."""
        return {
            "Accept": "application/vnd.github+json",
            "Authorization": f"Bearer {self.access_token}",
        }
42+
43+
44+
class GithubFileLoader(BaseGitHubLoader, ABC):
    """Load a single file from a GitHub repository via the contents API."""

    # Path of the file to load, relative to the repository root.
    file_path: str
    # Expected file extension (currently informational only).
    file_extension: str = ".md"
    # Branch (git ref) to read from.
    branch: str = "main"
    # Optional predicate on file paths. NOTE(review): declared but never
    # applied in this loader -- confirm whether it should gate load().
    file_filter: Optional[Callable[[str], bool]]

    def get_file_content_by_path(self, path: str) -> str:
        """Fetch and base64-decode the content of `path` on `self.branch`.

        Raises requests.HTTPError for non-2xx responses. Returns "" when
        the API returns a list (i.e. `path` resolves to a directory).
        """
        base_url = (
            f"{self.github_api_url}/repos/{self.repo}"
            f"/contents/{path}?ref={self.branch}"
        )
        response = requests.get(base_url, headers=self.headers)
        response.raise_for_status()

        # Parse once instead of calling response.json() twice.
        payload = response.json()
        if isinstance(payload, dict):
            # The contents API base64-encodes file bodies.
            return base64.b64decode(payload["content"]).decode("utf-8")
        return ""

    def load(self) -> List[Document]:
        """Load the configured file as a one-element list of Documents.

        Returns a materialized list rather than yielding: the previous
        generator form broke callers that take len() of the result, and
        a list matches the BaseLoader.load contract.
        """
        content = self.get_file_content_by_path(self.file_path)
        metadata = {
            "path": self.file_path,
            # NOTE(review): this builds the blob link on the API host
            # (api.github.com); a browsable URL would normally use
            # https://github.com -- confirm intended.
            "source": f"{self.github_api_url}/{self.repo}/blob/"
            f"{self.branch}/{self.file_path}",
        }
        return [Document(page_content=content, metadata=metadata)]

server/rag/retrieval.py

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
import json
2+
from typing import Optional
23
from langchain_openai import OpenAIEmbeddings
34
from langchain_community.vectorstores import SupabaseVectorStore
45
from db.supabase.client import get_client
5-
from data_class import S3Config
6+
from data_class import GitDocConfig, GitIssueConfig, S3Config
7+
from rag.github_file_loader import GithubFileLoader
68
from uilts.env import get_env_variable
79

10+
811
supabase_url = get_env_variable("SUPABASE_URL")
912
supabase_key = get_env_variable("SUPABASE_SERVICE_KEY")
13+
ACCESS_TOKEN=get_env_variable("GITHUB_TOKEN")
1014

1115

12-
table_name="antd_knowledge"
13-
query_name="match_antd_knowledge"
14-
chunk_size=2000
15-
16+
TABLE_NAME="rag_docs"
17+
QUERY_NAME="match_rag_docs"
18+
CHUNK_SIZE=2000
19+
CHUNK_OVERLAP=20
1620

1721
def convert_document_to_dict(document):
    """Extract the page content of a langchain Document.

    NOTE(review): the trailing comma makes this return a 1-tuple
    (page_content,), not a dict despite the name -- callers may rely on
    the tuple, so behavior is preserved; confirm intent before changing.
    """
    return document.page_content,
@@ -23,37 +27,84 @@ def init_retriever():
2327
db = SupabaseVectorStore(
2428
embedding=embeddings,
2529
client=get_client(),
26-
table_name=table_name,
27-
query_name=query_name,
28-
chunk_size=chunk_size,
30+
table_name=TABLE_NAME,
31+
query_name=QUERY_NAME,
32+
chunk_size=CHUNK_SIZE,
2933
)
3034

3135
return db.as_retriever()
3236

3337

34-
def add_knowledge(config: S3Config):
38+
def init_s3_Loader(config: S3Config):
    """Build an S3 directory loader for the configured bucket and prefix."""
    # Imported lazily so boto3 is only required when S3 ingestion is used.
    from langchain_community.document_loaders import S3DirectoryLoader
    loader = S3DirectoryLoader(config.s3_bucket, prefix=config.file_path)
    return loader
42+
43+
def init_github_issue_loader(config: GitIssueConfig):
    """Build a GitHub issues loader from the request config."""
    # Imported lazily to keep module import time low.
    from langchain_community.document_loaders import GitHubIssuesLoader

    loader = GitHubIssuesLoader(
        repo=config.repo_name,
        access_token=ACCESS_TOKEN,
        page=config.page,
        per_page=config.per_page,
        state=config.state,
    )
    return loader
54+
def init_github_file_loader(config: GitDocConfig):
    """Build a single-file GitHub loader from the request config."""
    loader = GithubFileLoader(
        repo=config.repo_name,
        access_token=ACCESS_TOKEN,
        github_api_url="https://api.github.com",
        branch=config.branch,
        file_path=config.file_path,
        # Only markdown files are meant to be indexed.
        file_filter=lambda file_path: file_path.endswith(".md"),
    )
    return loader
64+
65+
def supabase_embedding(documents):
    """Split documents, embed them with OpenAI, and store them in Supabase.

    Returns a JSON string describing success or failure; never raises.
    """
    from langchain_text_splitters import CharacterTextSplitter

    try:
        # Materialize first so len(documents) below works even when the
        # loader returns a lazy iterator/generator.
        documents = list(documents)
        text_splitter = CharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
        )
        docs = text_splitter.split_documents(documents)
        embeddings = OpenAIEmbeddings()
        SupabaseVectorStore.from_documents(
            docs,
            embeddings,
            client=get_client(),
            table_name=TABLE_NAME,
            query_name=QUERY_NAME,
            chunk_size=CHUNK_SIZE,
        )
        return json.dumps({
            "success": True,
            "message": "Knowledge added successfully!",
            "docs_len": len(documents)
        })
    except Exception as e:
        return json.dumps({
            "success": False,
            "message": str(e)
        })
90+
91+
92+
def add_knowledge_by_issues(config: GitIssueConfig):
    """Ingest a repository's GitHub issues into the knowledge store.

    Returns the JSON result of the embedding step. Previously the success
    result was discarded, so the function returned None on success.
    """
    try:
        loader = init_github_issue_loader(config)
        documents = loader.load()
        return supabase_embedding(documents)
    except Exception as e:
        return json.dumps({
            "success": False,
            "message": str(e)
        })
102+
103+
def add_knowledge_by_doc(config: GitDocConfig):
    """Ingest one documentation file from GitHub into the knowledge store.

    Returns the JSON result of the embedding step. Previously the success
    result was discarded, so the function returned None on success.
    """
    try:
        loader = init_github_file_loader(config)
        # list() so downstream len() works even if load() yields lazily.
        documents = list(loader.load())
        return supabase_embedding(documents)
    except Exception as e:
        return json.dumps({
            "success": False,
            "message": str(e)
        })

server/routers/rag.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import Optional
12
from fastapi import APIRouter, Depends
23
from rag import retrieval
3-
from data_class import S3Config
4+
from data_class import GitDocConfig, GitIssueConfig, S3Config
45
from verify.rate_limit import verify_rate_limit
56

67
router = APIRouter(
@@ -10,12 +11,17 @@
1011
)
1112

1213

13-
@router.post("/rag/add_knowledge", dependencies=[Depends(verify_rate_limit)])
14-
def add_knowledge(config: S3Config):
15-
data=retrieval.add_knowledge(config)
14+
# NOTE(review): the previous routes all carried
# dependencies=[Depends(verify_rate_limit)]; this commit dropped it from
# every endpoint, removing rate limiting. Restored below -- confirm if the
# removal was intentional.
@router.post("/rag/add_knowledge_by_doc", dependencies=[Depends(verify_rate_limit)])
def add_knowledge_by_doc(config: GitDocConfig):
    """Ingest a single GitHub documentation file into the knowledge base."""
    data = retrieval.add_knowledge_by_doc(config)
    return data

@router.post("/rag/add_knowledge_by_issues", dependencies=[Depends(verify_rate_limit)])
def add_knowledge_by_issues(config: GitIssueConfig):
    """Ingest a repository's GitHub issues into the knowledge base."""
    data = retrieval.add_knowledge_by_issues(config)
    return data

@router.post("/rag/search_knowledge", dependencies=[Depends(verify_rate_limit)])
def search_knowledge(query: str):
    """Search the knowledge base for documents matching `query`."""
    data = retrieval.search_knowledge(query)
    return data

server/sql/rag_docs.sql

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
-- Enable the pgvector extension to work with embedding vectors
create extension if not exists vector;

-- Create a table to store your rag_docs
create table rag_docs (
    id uuid primary key,
    content text,           -- corresponds to Document.pageContent
    metadata jsonb,          -- corresponds to Document.metadata
    embedding vector (1536)  -- 1536 works for OpenAI embeddings, change if needed
);

-- Create a function to search for rag_docs
create function match_rag_docs(
    query_embedding vector (1536),
    filter jsonb default '{}'
) returns table (
    id uuid,
    content text,
    metadata jsonb,
    similarity float
) language plpgsql as $$
#variable_conflict use_column
begin
    return query
    select
        id,
        content,
        metadata,
        1 - (rag_docs.embedding <=> query_embedding) as similarity
    from rag_docs
    where metadata @> filter
    order by rag_docs.embedding <=> query_embedding;
end;
$$;

0 commit comments

Comments
 (0)