
Commit 16472eb

solve knowledge graph issue when calling gemini model (infiniflow#2738)
### What problem does this PR solve?

infiniflow#2720

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
1 parent d92acdc commit 16472eb
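
The substantive fix is in GeminiChat: the knowledge graph pipeline can leave a "system"-role message inside the chat history, and Gemini's generate_content API only accepts "user"/"model" roles with a "parts" field, so the history is remapped before the call. Below is a minimal sketch of that remapping, using a made-up `history` list purely for illustration:

```python
# Sketch of the remapping GeminiChat.chat applies before generate_content();
# the `history` value here is a hypothetical example, not from the repository.
history = [
    {"role": "system", "content": "Extract entities and relations."},
    {"role": "user", "content": "Alice works at Acme."},
]

for item in history:
    if item.get("role") == "assistant":   # Gemini names the assistant role "model"
        item["role"] = "model"
    if item.get("role") == "system":      # Gemini rejects "system" inside history
        item["role"] = "user"
    if "content" in item:                 # Gemini expects "parts" instead of "content"
        item["parts"] = item.pop("content")

print(history[0])  # {'role': 'user', 'parts': 'Extract entities and relations.'}
```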

File tree

1 file changed: +64 -62 lines changed

rag/llm/chat_model.py

Lines changed: 64 additions & 62 deletions
@@ -23,7 +23,7 @@
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
-import os
+import os
 import json
 import requests
 import asyncio
@@ -62,17 +62,17 @@ def chat_streamly(self, system, history, gen_conf):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
+                    resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage") or not resp.usage
-                    else resp.usage.get("total_tokens",total_tokens)
+                    else resp.usage.get("total_tokens", total_tokens)
                 )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@@ -87,13 +87,13 @@ def chat_streamly(self, system, history, gen_conf):
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url="https://api.openai.com/v1"
+        if not base_url: base_url = "https://api.openai.com/v1"
         super().__init__(key, model_name, base_url)
 
 
 class MoonshotChat(Base):
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
-        if not base_url: base_url="https://api.moonshot.cn/v1"
+        if not base_url: base_url = "https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -108,7 +108,7 @@ def __init__(self, key=None, model_name="", base_url=""):
 
 class DeepSeekChat(Base):
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
-        if not base_url: base_url="https://api.deepseek.com/v1"
+        if not base_url: base_url = "https://api.deepseek.com/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -178,14 +178,14 @@ def chat_streamly(self, system, history, gen_conf):
                 stream=True,
                 **self._format_params(gen_conf))
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
+                    resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
-                        total_tokens
-                        + num_tokens_from_string(resp.choices[0].delta.content)
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage")
                     else resp.usage["total_tokens"]
@@ -252,7 +252,8 @@ def chat_streamly(self, system, history, gen_conf):
                             [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     yield ans
                 else:
-                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
@@ -298,7 +299,7 @@ def chat_streamly(self, system, history, gen_conf):
                 **gen_conf
             )
             for resp in response:
-                if not resp.choices[0].delta.content:continue
+                if not resp.choices[0].delta.content: continue
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
@@ -411,15 +412,15 @@ def __init__(self, key, model_name):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
             gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         answer = ""
         try:
             res = self.client.stream_doc(
@@ -463,10 +464,10 @@ def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/
 
 class MiniMaxChat(Base):
     def __init__(
-        self,
-        key,
-        model_name,
-        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+            self,
+            key,
+            model_name,
+            base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -583,7 +584,7 @@ def chat_streamly(self, system, history, gen_conf):
                 messages=history,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices or not resp.choices[0].delta.content: continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
@@ -620,19 +621,18 @@ def chat(self, system, history, gen_conf):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
 
         try:
             # Send the message to the model, using a basic inference configuration.
             response = self.client.converse(
                 modelId=self.model_name,
                 messages=history,
                 inferenceConfig=gen_conf,
-                system=[{"text": (system if system else "Answer the user's message.")}] ,
+                system=[{"text": (system if system else "Answer the user's message.")}],
             )
-
+
             # Extract and print the response text.
             ans = response["output"]["message"]["content"][0]["text"]
             return ans, num_tokens_from_string(ans)
@@ -652,9 +652,9 @@ def chat_streamly(self, system, history, gen_conf):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
+
         if self.model_name.split('.')[0] == 'ai21':
             try:
                 response = self.client.converse(
@@ -684,7 +684,7 @@ def chat_streamly(self, system, history, gen_conf):
                 if "contentBlockDelta" in resp:
                     ans += resp["contentBlockDelta"]["delta"]["text"]
                     yield ans
-
+
         except (ClientError, Exception) as e:
             yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
 
@@ -693,22 +693,21 @@ def chat_streamly(self, system, history, gen_conf):
 
 class GeminiChat(Base):
 
-    def __init__(self, key, model_name,base_url=None):
-        from google.generativeai import client,GenerativeModel
-
+    def __init__(self, key, model_name, base_url=None):
+        from google.generativeai import client, GenerativeModel
+
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = 'models/' + model_name
         self.model = GenerativeModel(model_name=self.model_name)
         self.model._client = _client
-
-
-    def chat(self,system,history,gen_conf):
+
+    def chat(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+
         if system:
             self.model._system_instruction = content_types.to_content(system)
-
+
         if 'max_tokens' in gen_conf:
             gen_conf['max_output_tokens'] = gen_conf['max_tokens']
         for k in list(gen_conf.keys()):
@@ -717,9 +716,11 @@ def chat(self,system,history,gen_conf):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if 'content' in item :
+            if 'role' in item and item['role'] == 'system':
+                item['role'] = 'user'
+            if 'content' in item:
                 item['parts'] = item.pop('content')
-
+
         try:
             response = self.model.generate_content(
                 history,
@@ -731,7 +732,7 @@ def chat(self,system,history,gen_conf):
 
     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+
         if system:
             self.model._system_instruction = content_types.to_content(system)
         if 'max_tokens' in gen_conf:
@@ -742,25 +743,25 @@ def chat_streamly(self, system, history, gen_conf):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if 'content' in item :
+            if 'content' in item:
                 item['parts'] = item.pop('content')
         ans = ""
         try:
             response = self.model.generate_content(
                 history,
-                generation_config=gen_conf,stream=True)
+                generation_config=gen_conf, stream=True)
             for resp in response:
                 ans += resp.text
                 yield ans
 
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
-        yield response._chunks[-1].usage_metadata.total_token_count
+        yield response._chunks[-1].usage_metadata.total_token_count
 
 
 class GroqChat:
-    def __init__(self, key, model_name,base_url=''):
+    def __init__(self, key, model_name, base_url=''):
         self.client = Groq(api_key=key)
         self.model_name = model_name
 
@@ -942,7 +943,7 @@ def chat_streamly(self, system, history, gen_conf):
 class LeptonAIChat(Base):
     def __init__(self, key, model_name, base_url=None):
         if not base_url:
-            base_url = os.path.join("https://"+model_name+".lepton.run","api","v1")
+            base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
         super().__init__(key, model_name, base_url)
 
 
@@ -1058,7 +1059,7 @@ def chat(self, system, history, gen_conf):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
         if "temperature" in gen_conf:
@@ -1084,7 +1085,7 @@ def chat_streamly(self, system, history, gen_conf):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
 
@@ -1121,7 +1122,7 @@ def chat_streamly(self, system, history, gen_conf):
 
 class SparkChat(Base):
     def __init__(
-        self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
+            self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
     ):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1141,26 +1142,27 @@ def __init__(self, key, model_name, base_url=None):
         import qianfan
 
         key = json.loads(key)
-        ak = key.get("yiyan_ak","")
-        sk = key.get("yiyan_sk","")
-        self.client = qianfan.ChatCompletion(ak=ak,sk=sk)
+        ak = key.get("yiyan_ak", "")
+        sk = key.get("yiyan_sk", "")
+        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
         self.model_name = model_name.lower()
         self.system = ""
 
     def chat(self, system, history, gen_conf):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
-        ) + 1
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                0)) / 2
+        ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
 
         try:
             response = self.client.do(
-                model=self.model_name,
-                messages=history,
+                model=self.model_name,
+                messages=history,
                 system=self.system,
                 **gen_conf
             ).body
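
In the Qianfan (ERNIE) client, the rewrapped expression still derives the single penalty_score passed to the API from the OpenAI-style presence/frequency penalties. A standalone sketch of the same arithmetic, with assumed input values:

```python
# Same computation as the diff, shown in isolation; the gen_conf values are assumed.
gen_conf = {"presence_penalty": 0.4, "frequency_penalty": 0.6}

gen_conf["penalty_score"] = (
    (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
) + 1

print(gen_conf["penalty_score"])  # (0.4 + 0.6) / 2 + 1 = 1.5
```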
@@ -1174,17 +1176,18 @@ def chat_streamly(self, system, history, gen_conf):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
-        ) + 1
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                0)) / 2
+        ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
         total_tokens = 0
 
         try:
             response = self.client.do(
-                model=self.model_name,
-                messages=history,
+                model=self.model_name,
+                messages=history,
                 system=self.system,
                 stream=True,
                 **gen_conf
@@ -1415,4 +1418,3 @@ def chat_streamly(self, system, history, gen_conf):
             yield ans + "\n**ERROR**: " + str(e)
 
         yield response._chunks[-1].usage_metadata.total_token_count
-
0 commit comments