Skip to content

Commit f9865e3

Browse files
committed
refactor-server: extract code from db_service
1 parent a411508 commit f9865e3

File tree

2 files changed

+35
-71
lines changed

2 files changed

+35
-71
lines changed

server/db_services/question.py

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import re
21
import os
32
import datetime
43

@@ -36,15 +35,14 @@ def save_question_to_db(
3635
async def update_question_state(
3736
id: int,
3837
answer: str,
38+
score: str,
39+
push_date: str
3940
):
4041
query = """
4142
UPDATE t_question
4243
SET last_answer = %s, progress = progress + %s, is_answered_today = %s, push_date = %s
4344
WHERE id = %s;
4445
"""
45-
chunks = answer.split("|||")
46-
score = _extract_score(chunks[1])
47-
push_date = get_push_date(score)
4846
data = (answer, score, "1", push_date, id, )
4947
MySQLHandler().update_table_data(query, data)
5048

@@ -60,30 +58,6 @@ def get_random_question_info():
6058
return MySQLHandler().execute_query(query, single=True)
6159

6260

63-
def _extract_score(anwser: str):
64-
score = re.findall(r"\d+\.?\d*", anwser)
65-
if score:
66-
return int(float(score[0]))
67-
else:
68-
return 0
69-
70-
71-
def get_push_date(score: int):
72-
now = datetime.datetime.now()
73-
days = 0
74-
75-
if (0 <= score <= 3):
76-
days = 1
77-
if (4 <= score <= 6):
78-
days = 3
79-
if (7 <= score <= 9):
80-
days = 7
81-
if (10 == score):
82-
days = 14
83-
84-
return ((now+datetime.timedelta(days)).strftime("%Y-%m-%d"))
85-
86-
8761
def get_expired_questions(note_id: int):
8862
now_date = datetime.date.today().strftime('%Y-%m-%d')
8963
query = """

server/llm_services/langchain_chain.py

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import re
33
import asyncio
4+
import datetime
45

56
from typing import Awaitable
67
from langchain import LLMChain
@@ -24,17 +25,6 @@ def __init__(
2425
temperature: int = 0,
2526
streaming: bool = False
2627
):
27-
"""
28-
Initializes a Chain instance.
29-
30-
:param note_id: The ID of the note.
31-
:param file_id: The ID of the file.
32-
:param filename: The filename.
33-
:param prompt_language: The language for prompts.
34-
:param prompt_type: The type of prompt.
35-
:param temperature: The temperature for language model.
36-
:param streaming: Whether streaming is enabled.
37-
"""
3828
self.semaphore = asyncio.Semaphore(
3929
_adjust_concurrency_by_payment_status())
4030
self.note_id = note_id
@@ -52,12 +42,6 @@ async def agenerate_questions(
5242
title: str,
5343
question_type: str
5444
):
55-
"""
56-
Generates questions for documents.
57-
58-
:param docs: List of Document objects.
59-
:param title: The title of the questions.
60-
"""
6145
tasks = []
6246
llm_chain = self._init_llm_chain(60, "", question_type)
6347
for doc in docs:
@@ -80,14 +64,6 @@ async def _agenerate_questions(
8064
doc_id: int,
8165
question_type: str
8266
):
83-
"""
84-
Generates questions for a document using the language model.
85-
86-
:param llm_chain: The LLMChain instance.
87-
:param doc: The Document object.
88-
:param title: The title for questions.
89-
:param doc_id: The ID of the document.
90-
"""
9167
async with self.semaphore:
9268
res = await llm_chain.apredict(
9369
title=title,
@@ -111,16 +87,6 @@ async def aexamine_answer(
11187
role: str,
11288
question_type: str,
11389
):
114-
"""
115-
Examines an answer using the language model.
116-
117-
:param id: The ID of the question.
118-
:param context: The context for examination.
119-
:param question: The question for examination.
120-
:param answer: The answer for examination.
121-
:param role: The role for the examination.
122-
:yield: The examination results.
123-
"""
12490
llm_chain = self._init_llm_chain(60, role, question_type)
12591
coroutine = wait_done(llm_chain.apredict(
12692
context=context,
@@ -141,16 +107,16 @@ async def aexamine_answer(
141107
yield str(e)
142108
return
143109

144-
await _dbs_.question.update_question_state(id, f"{answer} ||| {exmine}")
110+
score = _extract_score(exmine)
111+
push_date = _get_push_date(score)
112+
await _dbs_.question.update_question_state(
113+
id=id,
114+
answer=f"{answer} ||| {exmine}",
115+
score=score,
116+
push_date=push_date
117+
)
145118

146119
def _init_llm_chain(self, timeout: int, role: str, question_type: str):
147-
"""
148-
Initializes the language model chain.
149-
150-
:param timeout: The timeout value.
151-
:param role: The role for the examination.
152-
:return: The initialized LLMChain instance.
153-
"""
154120
llm_instance = LLM(
155121
temperature=self.temperature,
156122
streaming=self.streaming,
@@ -231,3 +197,27 @@ def _is_legal_question_structure(
231197
def _remove_prefix_numbers(text):
232198
cleaned_text = re.sub(r'^\s*(?:\d+\.|-)\s*', '', text)
233199
return cleaned_text.strip()
200+
201+
202+
def _extract_score(anwser: str):
203+
score = re.findall(r"\d+\.?\d*", anwser)
204+
if score:
205+
return int(float(score[0]))
206+
else:
207+
return 0
208+
209+
210+
def _get_push_date(score: int):
211+
now = datetime.datetime.now()
212+
days = 0
213+
214+
if (0 <= score <= 3):
215+
days = 1
216+
if (4 <= score <= 6):
217+
days = 3
218+
if (7 <= score <= 9):
219+
days = 7
220+
if (10 == score):
221+
days = 14
222+
223+
return ((now+datetime.timedelta(days)).strftime("%Y-%m-%d"))

0 commit comments

Comments
 (0)