2
2
import re
3
3
import asyncio
4
4
import datetime
5
-
6
5
from typing import Awaitable
7
6
from langchain import LLMChain
8
7
from langchain .schema import Document
9
8
from langchain .callbacks import AsyncIteratorCallbackHandler
10
9
11
10
import db_services as _dbs_
12
-
13
- from utils .ebbinghaus import handle_ebbinghaus_memory
14
11
from .langchain_llm import LLM
12
+ from utils .ebbinghaus import handle_ebbinghaus_memory
15
13
from prompts import choose_prompt
16
14
17
15
16
+ # Main Chain class, which encapsulates LLMChain configurations and various methods.
18
17
class Chain :
19
18
def __init__ (
20
19
self ,
@@ -37,6 +36,7 @@ def __init__(
37
36
self .streaming = streaming
38
37
self .llm_callbacks = [AsyncIteratorCallbackHandler ()]
39
38
39
+ # Asynchronous method to generate questions.
40
40
async def agenerate_questions (
41
41
self ,
42
42
docs : list [Document ],
@@ -57,6 +57,7 @@ async def agenerate_questions(
57
57
_dbs_ .file .set_file_is_uploading_state (self .file_id )
58
58
raise e
59
59
60
+ # Helper method for specific question generation.
60
61
async def _agenerate_questions (
61
62
self ,
62
63
llm_chain : LLMChain ,
@@ -79,6 +80,7 @@ async def _agenerate_questions(
79
80
question_type = question_type
80
81
)
81
82
83
+ # Method to check answers.
82
84
async def aexamine_answer (
83
85
self ,
84
86
quesiton_id : int ,
@@ -139,6 +141,7 @@ def _init_llm_chain(self, timeout: int, role: str, question_type: str):
139
141
)
140
142
141
143
144
+ # Helper function to wait for asynchronous tasks to complete.
142
145
async def _wait_done (
143
146
fn : Awaitable ,
144
147
event : asyncio .Event
@@ -152,6 +155,7 @@ async def _wait_done(
152
155
event .set ()
153
156
154
157
158
+ # Adjust concurrency based on payment status.
155
159
def _adjust_concurrency_by_payment_status ():
156
160
payment = os .environ .get ("PAYMENT" , "free" )
157
161
if (payment == "free" ):
@@ -160,6 +164,7 @@ def _adjust_concurrency_by_payment_status():
160
164
return 3
161
165
162
166
167
+ # Adjust the number of retries based on payment status.
163
168
def _adjust_retries_by_payment_status ():
164
169
payment = os .environ .get ("PAYMENT" , "free" )
165
170
if (payment == "free" ):
@@ -168,6 +173,7 @@ def _adjust_retries_by_payment_status():
168
173
return 6
169
174
170
175
176
+ # Split the generated questions.
171
177
def _spite_questions (
172
178
content : str ,
173
179
type : str
@@ -180,6 +186,7 @@ def _spite_questions(
180
186
return questions
181
187
182
188
189
+ # Check if the structure of the question is valid.
183
190
def _is_legal_question_structure (
184
191
content : str ,
185
192
type : str
@@ -195,11 +202,13 @@ def _is_legal_question_structure(
195
202
return True
196
203
197
204
205
+ # Remove prefix numbers or dashes from a question.
198
206
def _remove_prefix_numbers (text ):
199
207
cleaned_text = re .sub (r'^\s*(?:\d+\.|-)\s*' , '' , text )
200
208
return cleaned_text .strip ()
201
209
202
210
211
+ # Extract score from the answer.
203
212
def _extract_score (anwser : str ):
204
213
score = re .findall (r"\d+\.?\d*" , anwser )
205
214
if score :
@@ -208,6 +217,7 @@ def _extract_score(anwser: str):
208
217
return 0
209
218
210
219
220
+ # Get the push date based on the score.
211
221
def _get_push_date (score : int ):
212
222
now = datetime .datetime .now ()
213
223
days = handle_ebbinghaus_memory (score )
0 commit comments