@@ -23,7 +23,7 @@
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
-import os
+import os
 import json
 import requests
 import asyncio
@@ -62,17 +62,17 @@ def chat_streamly(self, system, history, gen_conf):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
+                    resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage") or not resp.usage
-                    else resp.usage.get("total_tokens",total_tokens)
+                    else resp.usage.get("total_tokens", total_tokens)
                 )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@@ -87,13 +87,13 @@ def chat_streamly(self, system, history, gen_conf):
 
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url = "https://api.openai.com/v1"
+        if not base_url: base_url = "https://api.openai.com/v1"
         super().__init__(key, model_name, base_url)
 
 
 class MoonshotChat(Base):
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
-        if not base_url: base_url = "https://api.moonshot.cn/v1"
+        if not base_url: base_url = "https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
@@ -108,7 +108,7 @@ def __init__(self, key=None, model_name="", base_url=""):
 
 class DeepSeekChat(Base):
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
-        if not base_url: base_url = "https://api.deepseek.com/v1"
+        if not base_url: base_url = "https://api.deepseek.com/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -178,14 +178,14 @@ def chat_streamly(self, system, history, gen_conf):
                 stream=True,
                 **self._format_params(gen_conf))
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
+                    resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage")
                     else resp.usage["total_tokens"]
@@ -252,7 +252,8 @@ def chat_streamly(self, system, history, gen_conf):
                             [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     yield ans
                 else:
-                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
@@ -298,7 +299,7 @@ def chat_streamly(self, system, history, gen_conf):
                 **gen_conf
             )
             for resp in response:
-                if not resp.choices[0].delta.content:continue
+                if not resp.choices[0].delta.content: continue
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
@@ -411,15 +412,15 @@ def __init__(self, key, model_name):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
             gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         answer = ""
         try:
             res = self.client.stream_doc(
@@ -463,10 +464,10 @@ def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/
 
 class MiniMaxChat(Base):
     def __init__(
-        self,
-        key,
-        model_name,
-        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+            self,
+            key,
+            model_name,
+            base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -583,7 +584,7 @@ def chat_streamly(self, system, history, gen_conf):
                 messages=history,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices or not resp.choices[0].delta.content: continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
@@ -620,19 +621,18 @@ def chat(self, system, history, gen_conf):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
 
         try:
             # Send the message to the model, using a basic inference configuration.
             response = self.client.converse(
                 modelId=self.model_name,
                 messages=history,
                 inferenceConfig=gen_conf,
-                system=[{"text": (system if system else "Answer the user's message.")}] ,
+                system=[{"text": (system if system else "Answer the user's message.")}],
             )
-
+
             # Extract and print the response text.
             ans = response["output"]["message"]["content"][0]["text"]
             return ans, num_tokens_from_string(ans)
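The `isinstance` check in this hunk normalizes plain-string message content into the block-list shape that Bedrock's Converse API expects, where `content` is a list of blocks such as `{"text": ...}`. A self-contained sketch of that normalization (helper name is hypothetical):

```python
# Sketch: Bedrock Converse wants message content as a list of blocks,
# e.g. {"role": "user", "content": [{"text": "hi"}]}. Plain strings are
# wrapped in place, matching the loop in the hunk above.
def normalize_bedrock_message(message: dict) -> dict:
    if not isinstance(message["content"], (list, tuple)):
        message["content"] = [{"text": message["content"]}]
    return message
```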
@@ -652,9 +652,9 @@ def chat_streamly(self, system, history, gen_conf):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
+
         if self.model_name.split('.')[0] == 'ai21':
             try:
                 response = self.client.converse(
@@ -684,7 +684,7 @@ def chat_streamly(self, system, history, gen_conf):
                 if "contentBlockDelta" in resp:
                     ans += resp["contentBlockDelta"]["delta"]["text"]
                     yield ans
-
+
         except (ClientError, Exception) as e:
             yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
 
@@ -693,22 +693,21 @@ def chat_streamly(self, system, history, gen_conf):
 
 class GeminiChat(Base):
 
-    def __init__(self, key, model_name,base_url=None):
-        from google.generativeai import client,GenerativeModel
-
+    def __init__(self, key, model_name, base_url=None):
+        from google.generativeai import client, GenerativeModel
+
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = 'models/' + model_name
         self.model = GenerativeModel(model_name=self.model_name)
         self.model._client = _client
-
-
-    def chat(self,system,history,gen_conf):
+
+    def chat(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+
         if system:
             self.model._system_instruction = content_types.to_content(system)
-
+
         if 'max_tokens' in gen_conf:
             gen_conf['max_output_tokens'] = gen_conf['max_tokens']
         for k in list(gen_conf.keys()):
@@ -717,9 +716,11 @@ def chat(self,system,history,gen_conf):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if 'content' in item:
+            if 'role' in item and item['role'] == 'system':
+                item['role'] = 'user'
+            if 'content' in item:
                 item['parts'] = item.pop('content')
-
+
         try:
             response = self.model.generate_content(
                 history,
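The new branch in this hunk remaps a `system` role left in the history to `user`: the `google.generativeai` chat format only accepts `user` and `model` roles inside the message list, with the system prompt carried separately via `_system_instruction`. A sketch of the full history conversion this loop performs, as a hypothetical standalone helper:

```python
# Sketch: convert OpenAI-style messages into the Gemini history shape
# built by the loop above ('assistant' -> 'model', stray 'system' ->
# 'user', 'content' -> 'parts'). Hypothetical helper, not in the diff.
def to_gemini_history(messages):
    out = []
    for m in messages:
        role = m.get("role", "user")
        if role == "assistant":
            role = "model"
        elif role == "system":
            role = "user"
        out.append({"role": role, "parts": m["content"]})
    return out
```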
@@ -731,7 +732,7 @@ def chat(self,system,history,gen_conf):
 
     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+
         if system:
             self.model._system_instruction = content_types.to_content(system)
         if 'max_tokens' in gen_conf:
@@ -742,25 +743,25 @@ def chat_streamly(self, system, history, gen_conf):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if 'content' in item:
+            if 'content' in item:
                 item['parts'] = item.pop('content')
         ans = ""
         try:
             response = self.model.generate_content(
                 history,
-                generation_config=gen_conf,stream=True)
+                generation_config=gen_conf, stream=True)
             for resp in response:
                 ans += resp.text
                 yield ans
 
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
-        yield response._chunks[-1].usage_metadata.total_token_count
+        yield response._chunks[-1].usage_metadata.total_token_count
 
 
 class GroqChat:
-    def __init__(self, key, model_name,base_url=''):
+    def __init__(self, key, model_name, base_url=''):
         self.client = Groq(api_key=key)
         self.model_name = model_name
 
@@ -942,7 +943,7 @@ def chat_streamly(self, system, history, gen_conf):
 class LeptonAIChat(Base):
     def __init__(self, key, model_name, base_url=None):
         if not base_url:
-            base_url = os.path.join("https://" + model_name + ".lepton.run","api","v1")
+            base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
         super().__init__(key, model_name, base_url)
 
 
@@ -1058,7 +1059,7 @@ def chat(self, system, history, gen_conf):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
         if "temperature" in gen_conf:
@@ -1084,7 +1085,7 @@ def chat_streamly(self, system, history, gen_conf):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
 
@@ -1121,7 +1122,7 @@ def chat_streamly(self, system, history, gen_conf):
 
 class SparkChat(Base):
     def __init__(
-        self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
+            self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
     ):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1141,26 +1142,27 @@ def __init__(self, key, model_name, base_url=None):
         import qianfan
 
         key = json.loads(key)
-        ak = key.get("yiyan_ak","")
-        sk = key.get("yiyan_sk","")
-        self.client = qianfan.ChatCompletion(ak=ak,sk=sk)
+        ak = key.get("yiyan_ak", "")
+        sk = key.get("yiyan_sk", "")
+        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
         self.model_name = model_name.lower()
         self.system = ""
 
     def chat(self, system, history, gen_conf):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
-        ) + 1
+                                            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                                                0)) / 2
+                                    ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
 
         try:
             response = self.client.do(
-                model=self.model_name,
-                messages=history,
+                model=self.model_name,
+                messages=history,
                 system=self.system,
                 **gen_conf
             ).body
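For reference, ERNIE's qianfan SDK takes a single `penalty_score` (roughly in the 1.0 to 2.0 range) instead of OpenAI's separate presence and frequency penalties; the expression being reflowed here averages the two and shifts the result up by 1. A tiny worked example of that mapping, with a hypothetical helper name:

```python
# Sketch of the OpenAI -> qianfan penalty mapping used above.
def to_penalty_score(presence: float = 0.0, frequency: float = 0.0) -> float:
    return (presence + frequency) / 2 + 1

# e.g. presence_penalty=0.4, frequency_penalty=0.6 -> penalty_score=1.5
assert to_penalty_score(0.4, 0.6) == 1.5
```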
@@ -1174,17 +1176,18 @@ def chat_streamly(self, system, history, gen_conf):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
-        ) + 1
+                                            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                                                0)) / 2
+                                    ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
         total_tokens = 0
 
         try:
             response = self.client.do(
-                model=self.model_name,
-                messages=history,
+                model=self.model_name,
+                messages=history,
                 system=self.system,
                 stream=True,
                 **gen_conf
@@ -1415,4 +1418,3 @@ def chat_streamly(self, system, history, gen_conf):
             yield ans + "\n**ERROR**: " + str(e)
 
         yield response._chunks[-1].usage_metadata.total_token_count
-