Skip to content

Commit 6eed115

Browse files
Feiueliuhua
and
liuhua
authored
Refactor API for document and session (infiniflow#2819)
### What problem does this PR solve? Refactor API for document and session. ### Type of change - [x] Refactoring --------- Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
1 parent 7d80fc4 commit 6eed115

File tree

9 files changed

+829
-753
lines changed

9 files changed

+829
-753
lines changed

api/apps/sdk/doc.py

Lines changed: 290 additions & 457 deletions
Large diffs are not rendered by default.

api/apps/sdk/session.py

Lines changed: 82 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -20,81 +20,77 @@
2020

2121
from api.db import StatusEnum
2222
from api.db.services.dialog_service import DialogService, ConversationService, chat
23-
from api.settings import RetCode
2423
from api.utils import get_uuid
25-
from api.utils.api_utils import get_data_error_result
26-
from api.utils.api_utils import get_json_result, token_required
24+
from api.utils.api_utils import get_error_data_result
25+
from api.utils.api_utils import get_result, token_required
2726

28-
29-
@manager.route('/save', methods=['POST'])
27+
@manager.route('/chat/<chat_id>/session', methods=['POST'])
3028
@token_required
31-
def set_conversation(tenant_id):
29+
def create(tenant_id,chat_id):
3230
req = request.json
33-
conv_id = req.get("id")
34-
if "assistant_id" in req:
35-
req["dialog_id"] = req.pop("assistant_id")
36-
if "id" in req:
37-
del req["id"]
38-
conv = ConversationService.query(id=conv_id)
39-
if not conv:
40-
return get_data_error_result(retmsg="Session does not exist")
41-
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
42-
return get_data_error_result(retmsg="You do not own the session")
43-
if req.get("dialog_id"):
44-
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
45-
if not dia:
46-
return get_data_error_result(retmsg="You do not own the assistant")
47-
if "dialog_id" in req and not req.get("dialog_id"):
48-
return get_data_error_result(retmsg="assistant_id can not be empty.")
49-
if "message" in req:
50-
return get_data_error_result(retmsg="message can not be change")
51-
if "reference" in req:
52-
return get_data_error_result(retmsg="reference can not be change")
53-
if "name" in req and not req.get("name"):
54-
return get_data_error_result(retmsg="name can not be empty.")
55-
if not ConversationService.update_by_id(conv_id, req):
56-
return get_data_error_result(retmsg="Session updates error")
57-
return get_json_result(data=True)
58-
59-
if not req.get("dialog_id"):
60-
return get_data_error_result(retmsg="assistant_id is required.")
31+
req["dialog_id"] = chat_id
6132
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
6233
if not dia:
63-
return get_data_error_result(retmsg="You do not own the assistant")
34+
return get_error_data_result(retmsg="You do not own the assistant")
6435
conv = {
6536
"id": get_uuid(),
6637
"dialog_id": req["dialog_id"],
6738
"name": req.get("name", "New session"),
6839
"message": [{"role": "assistant", "content": "Hi! I am your assistant,can I help you?"}]
6940
}
7041
if not conv.get("name"):
71-
return get_data_error_result(retmsg="name can not be empty.")
42+
return get_error_data_result(retmsg="Name can not be empty.")
7243
ConversationService.save(**conv)
7344
e, conv = ConversationService.get_by_id(conv["id"])
7445
if not e:
75-
return get_data_error_result(retmsg="Fail to new session!")
46+
return get_error_data_result(retmsg="Fail to create a session!")
7647
conv = conv.to_dict()
7748
conv['messages'] = conv.pop("message")
78-
conv["assistant_id"] = conv.pop("dialog_id")
49+
conv["chat_id"] = conv.pop("dialog_id")
7950
del conv["reference"]
80-
return get_json_result(data=conv)
51+
return get_result(data=conv)
8152

82-
83-
@manager.route('/completion', methods=['POST'])
53+
@manager.route('/chat/<chat_id>/session/<session_id>', methods=['PUT'])
8454
@token_required
85-
def completion(tenant_id):
55+
def update(tenant_id,chat_id,session_id):
56+
req = request.json
57+
if "dialog_id" in req and req.get("dialog_id") != chat_id:
58+
return get_error_data_result(retmsg="Can't change chat_id")
59+
if "chat_id" in req and req.get("chat_id") != chat_id:
60+
return get_error_data_result(retmsg="Can't change chat_id")
61+
req["dialog_id"] = chat_id
62+
conv_id = session_id
63+
conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
64+
if not conv:
65+
return get_error_data_result(retmsg="Session does not exist")
66+
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
67+
return get_error_data_result(retmsg="You do not own the session")
68+
if "message" in req or "messages" in req:
69+
return get_error_data_result(retmsg="Message can not be change")
70+
if "reference" in req:
71+
return get_error_data_result(retmsg="Reference can not be change")
72+
if "name" in req and not req.get("name"):
73+
return get_error_data_result(retmsg="Name can not be empty.")
74+
if not ConversationService.update_by_id(conv_id, req):
75+
return get_error_data_result(retmsg="Session updates error")
76+
return get_result()
77+
78+
79+
@manager.route('/chat/<chat_id>/session/<session_id>/completion', methods=['POST'])
80+
@token_required
81+
def completion(tenant_id,chat_id,session_id):
8682
req = request.json
8783
# req = {"conversation_id": "9aaaca4c11d311efa461fa163e197198", "messages": [
8884
# {"role": "user", "content": "上海有吗?"}
8985
# ]}
90-
if "session_id" not in req:
91-
return get_data_error_result(retmsg="session_id is required")
92-
conv = ConversationService.query(id=req["session_id"])
86+
if not req.get("question"):
87+
return get_error_data_result(retmsg="Please input your question.")
88+
conv = ConversationService.query(id=session_id,dialog_id=chat_id)
9389
if not conv:
94-
return get_data_error_result(retmsg="Session does not exist")
90+
return get_error_data_result(retmsg="Session does not exist")
9591
conv = conv[0]
96-
if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
97-
return get_data_error_result(retmsg="You do not own the session")
92+
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
93+
return get_error_data_result(retmsg="You do not own the session")
9894
msg = []
9995
question = {
10096
"content": req.get("question"),
@@ -108,7 +104,6 @@ def completion(tenant_id):
108104
msg.append(m)
109105
message_id = msg[-1].get("id")
110106
e, dia = DialogService.get_by_id(conv.dialog_id)
111-
del req["session_id"]
112107

113108
if not conv.reference:
114109
conv.reference = []
@@ -130,13 +125,13 @@ def stream():
130125
try:
131126
for ans in chat(dia, msg, **req):
132127
fillin_conv(ans)
133-
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
128+
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
134129
ConversationService.update_by_id(conv.id, conv.to_dict())
135130
except Exception as e:
136-
yield "data:" + json.dumps({"retcode": 500, "retmsg": str(e),
131+
yield "data:" + json.dumps({"code": 500, "message": str(e),
137132
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
138133
ensure_ascii=False) + "\n\n"
139-
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": True}, ensure_ascii=False) + "\n\n"
134+
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
140135

141136
if req.get("stream", True):
142137
resp = Response(stream(), mimetype="text/event-stream")
@@ -153,73 +148,31 @@ def stream():
153148
fillin_conv(ans)
154149
ConversationService.update_by_id(conv.id, conv.to_dict())
155150
break
156-
return get_json_result(data=answer)
157-
151+
return get_result(data=answer)
158152

159-
@manager.route('/get', methods=['GET'])
153+
@manager.route('/chat/<chat_id>/session', methods=['GET'])
160154
@token_required
161-
def get(tenant_id):
162-
req = request.args
163-
if "id" not in req:
164-
return get_data_error_result(retmsg="id is required")
165-
conv_id = req["id"]
166-
conv = ConversationService.query(id=conv_id)
167-
if not conv:
168-
return get_data_error_result(retmsg="Session does not exist")
169-
if not DialogService.query(id=conv[0].dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
170-
return get_data_error_result(retmsg="You do not own the session")
171-
if "assistant_id" in req:
172-
if req["assistant_id"] != conv[0].dialog_id:
173-
return get_data_error_result(retmsg="The session doesn't belong to the assistant")
174-
conv = conv[0].to_dict()
175-
conv['messages'] = conv.pop("message")
176-
conv["assistant_id"] = conv.pop("dialog_id")
177-
if conv["reference"]:
178-
messages = conv["messages"]
179-
message_num = 0
180-
chunk_num = 0
181-
while message_num < len(messages):
182-
if message_num != 0 and messages[message_num]["role"] != "user":
183-
chunk_list = []
184-
if "chunks" in conv["reference"][chunk_num]:
185-
chunks = conv["reference"][chunk_num]["chunks"]
186-
for chunk in chunks:
187-
new_chunk = {
188-
"id": chunk["chunk_id"],
189-
"content": chunk["content_with_weight"],
190-
"document_id": chunk["doc_id"],
191-
"document_name": chunk["docnm_kwd"],
192-
"knowledgebase_id": chunk["kb_id"],
193-
"image_id": chunk["img_id"],
194-
"similarity": chunk["similarity"],
195-
"vector_similarity": chunk["vector_similarity"],
196-
"term_similarity": chunk["term_similarity"],
197-
"positions": chunk["positions"],
198-
}
199-
chunk_list.append(new_chunk)
200-
chunk_num += 1
201-
messages[message_num]["reference"] = chunk_list
202-
message_num += 1
203-
del conv["reference"]
204-
return get_json_result(data=conv)
205-
206-
207-
@manager.route('/list', methods=["GET"])
208-
@token_required
209-
def list(tenant_id):
210-
assistant_id = request.args["assistant_id"]
211-
if not DialogService.query(tenant_id=tenant_id, id=assistant_id, status=StatusEnum.VALID.value):
212-
return get_json_result(
213-
data=False, retmsg=f"You don't own the assistant.",
214-
retcode=RetCode.OPERATING_ERROR)
215-
convs = ConversationService.query(
216-
dialog_id=assistant_id,
217-
order_by=ConversationService.model.create_time,
218-
reverse=True)
219-
convs = [d.to_dict() for d in convs]
155+
def list(chat_id,tenant_id):
156+
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
157+
return get_error_data_result(retmsg=f"You don't own the assistant {chat_id}.")
158+
id = request.args.get("id")
159+
name = request.args.get("name")
160+
session = ConversationService.query(id=id,name=name,dialog_id=chat_id)
161+
if not session:
162+
return get_error_data_result(retmsg="The session doesn't exist")
163+
page_number = int(request.args.get("page", 1))
164+
items_per_page = int(request.args.get("page_size", 1024))
165+
orderby = request.args.get("orderby", "create_time")
166+
if request.args.get("desc") == "False":
167+
desc = False
168+
else:
169+
desc = True
170+
convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
171+
if not convs:
172+
return get_result(data=[])
220173
for conv in convs:
221174
conv['messages'] = conv.pop("message")
222-
conv["assistant_id"] = conv.pop("dialog_id")
175+
conv["chat"] = conv.pop("dialog_id")
223176
if conv["reference"]:
224177
messages = conv["messages"]
225178
message_num = 0
@@ -247,20 +200,19 @@ def list(tenant_id):
247200
messages[message_num]["reference"] = chunk_list
248201
message_num += 1
249202
del conv["reference"]
250-
return get_json_result(data=convs)
203+
return get_result(data=convs)
251204

252-
253-
@manager.route('/delete', methods=["DELETE"])
205+
@manager.route('/chat/<chat_id>/session', methods=["DELETE"])
254206
@token_required
255-
def delete(tenant_id):
256-
id = request.args.get("id")
257-
if not id:
258-
return get_data_error_result(retmsg="`id` is required in deleting operation")
259-
conv = ConversationService.query(id=id)
260-
if not conv:
261-
return get_data_error_result(retmsg="Session doesn't exist")
262-
conv = conv[0]
263-
if not DialogService.query(id=conv.dialog_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
264-
return get_data_error_result(retmsg="You don't own the session")
265-
ConversationService.delete_by_id(id)
266-
return get_json_result(data=True)
207+
def delete(tenant_id,chat_id):
208+
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
209+
return get_error_data_result(retmsg="You don't own the chat")
210+
ids = request.json.get("ids")
211+
if not ids:
212+
return get_error_data_result(retmsg="`ids` is required in deleting operation")
213+
for id in ids:
214+
conv = ConversationService.query(id=id,dialog_id=chat_id)
215+
if not conv:
216+
return get_error_data_result(retmsg="The chat doesn't own the session")
217+
ConversationService.delete_by_id(id)
218+
return get_result()

api/db/services/dialog_service.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import re
2020
from copy import deepcopy
2121
from timeit import default_timer as timer
22+
23+
2224
from api.db import LLMType, ParserType,StatusEnum
2325
from api.db.db_models import Dialog, Conversation,DB
2426
from api.db.services.common_service import CommonService
@@ -61,6 +63,22 @@ def get_list(cls, tenant_id,
6163
class ConversationService(CommonService):
6264
model = Conversation
6365

66+
@classmethod
67+
@DB.connection_context()
68+
def get_list(cls,dialog_id,page_number, items_per_page, orderby, desc, id , name):
69+
sessions = cls.model.select().where(cls.model.dialog_id ==dialog_id)
70+
if id:
71+
sessions = sessions.where(cls.model.id == id)
72+
if name:
73+
sessions = sessions.where(cls.model.name == name)
74+
if desc:
75+
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
76+
else:
77+
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
78+
79+
sessions = sessions.paginate(page_number, items_per_page)
80+
81+
return list(sessions.dicts())
6482

6583
def message_fit_in(msg, max_length=4000):
6684
def count():

api/db/services/document_service.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,29 @@
4949
class DocumentService(CommonService):
5050
model = Document
5151

52+
@classmethod
53+
@DB.connection_context()
54+
def get_list(cls, kb_id, page_number, items_per_page,
55+
orderby, desc, keywords, id):
56+
docs =cls.model.select().where(cls.model.kb_id==kb_id)
57+
if id:
58+
docs = docs.where(
59+
cls.model.id== id )
60+
if keywords:
61+
docs = docs.where(
62+
fn.LOWER(cls.model.name).contains(keywords.lower())
63+
)
64+
count = docs.count()
65+
if desc:
66+
docs = docs.order_by(cls.model.getter_by(orderby).desc())
67+
else:
68+
docs = docs.order_by(cls.model.getter_by(orderby).asc())
69+
70+
docs = docs.paginate(page_number, items_per_page)
71+
72+
return list(docs.dicts()), count
73+
74+
5275
@classmethod
5376
@DB.connection_context()
5477
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
@@ -268,7 +291,7 @@ def get_doc_id_by_doc_name(cls, doc_name):
268291
@classmethod
269292
@DB.connection_context()
270293
def get_thumbnails(cls, docids):
271-
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
294+
fields = [cls.model.id, cls.model.thumbnail]
272295
return list(cls.model.select(
273296
*fields).where(cls.model.id.in_(docids)).dicts())
274297

0 commit comments

Comments
 (0)