Skip to content

Commit 2bc6c42

Browse files
support api-version in adding azure-openai
1 parent 51efecf commit 2bc6c42

File tree

11 files changed

+202
-10
lines changed

11 files changed

+202
-10
lines changed

api/apps/llm_app.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def set_api_key():
5858
chat_passed, embd_passed, rerank_passed = False, False, False
5959
factory = req["llm_factory"]
6060
msg = ""
61-
for llm in LLMService.query(fid=factory)[:3]:
61+
for llm in LLMService.query(fid=factory):
6262
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
6363
mdl = EmbeddingModel[factory](
6464
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
@@ -77,21 +77,25 @@ def set_api_key():
7777
{"temperature": 0.9,'max_tokens':50})
7878
if m.find("**ERROR**") >=0:
7979
raise Exception(m)
80+
chat_passed = True
8081
except Exception as e:
8182
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
8283
e)
83-
chat_passed = True
8484
elif not rerank_passed and llm.model_type == LLMType.RERANK:
8585
mdl = RerankModel[factory](
8686
req["api_key"], llm.llm_name, base_url=req.get("base_url"))
8787
try:
8888
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
8989
if len(arr) == 0 or tc == 0:
9090
raise Exception("Fail")
91+
rerank_passed = True
92+
print(f'passed model rerank{llm.llm_name}',flush=True)
9193
except Exception as e:
9294
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(
9395
e)
94-
rerank_passed = True
96+
if any([embd_passed, chat_passed, rerank_passed]):
97+
msg = ''
98+
break
9599

96100
if msg:
97101
return get_data_error_result(retmsg=msg)
@@ -183,6 +187,9 @@ def apikey_json(keys):
183187
llm_name = req["llm_name"]
184188
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
185189

190+
elif factory == "Azure-OpenAI":
191+
llm_name = req["llm_name"]
192+
api_key = apikey_json(["api_key", "api_version"])
186193
else:
187194
llm_name = req["llm_name"]
188195
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")

conf/llm_factories.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,13 +619,13 @@
619619
"model_type": "chat,image2text"
620620
},
621621
{
622-
"llm_name": "gpt-35-turbo",
622+
"llm_name": "gpt-3.5-turbo",
623623
"tags": "LLM,CHAT,4K",
624624
"max_tokens": 4096,
625625
"model_type": "chat"
626626
},
627627
{
628-
"llm_name": "gpt-35-turbo-16k",
628+
"llm_name": "gpt-3.5-turbo-16k",
629629
"tags": "LLM,CHAT,16k",
630630
"max_tokens": 16385,
631631
"model_type": "chat"

rag/llm/chat_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepse
114114

115115
class AzureChat(Base):
116116
def __init__(self, key, model_name, **kwargs):
117-
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
117+
api_key = json.loads(key).get('api_key', '')
118+
api_version = json.loads(key).get('api_version', '2024-02-01')
119+
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
118120
self.model_name = model_name
119121

120122

rag/llm/cv_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ def describe(self, image, max_tokens=300):
160160

161161
class AzureGptV4(Base):
162162
def __init__(self, key, model_name, lang="Chinese", **kwargs):
163-
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
163+
api_key = json.loads(key).get('api_key', '')
164+
api_version = json.loads(key).get('api_version', '2024-02-01')
165+
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
164166
self.model_name = model_name
165167
self.lang = lang
166168

rag/llm/embedding_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ def encode_queries(self, text):
137137
class AzureEmbed(OpenAIEmbed):
138138
def __init__(self, key, model_name, **kwargs):
139139
from openai.lib.azure import AzureOpenAI
140-
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
140+
api_key = json.loads(key).get('api_key', '')
141+
api_version = json.loads(key).get('api_version', '2024-02-01')
142+
self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
141143
self.model_name = model_name
142144

143145

web/src/locales/en.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,8 @@ The above is the content you need to summarize.`,
581581
GoogleRegionMessage: 'Please input Google Cloud Region',
582582
modelProvidersWarn:
583583
'Please add both embedding model and LLM in <b>Settings > Model providers</b> firstly.',
584+
apiVersion: 'API-Version',
585+
apiVersionMessage: 'Please input API version',
584586
},
585587
message: {
586588
registered: 'Registered!',

web/src/locales/zh.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,8 @@ export default {
557557
GoogleRegionMessage: '请输入 Google Cloud 区域',
558558
modelProvidersWarn:
559559
'请首先在 <b>设置 > 模型提供商</b> 中添加嵌入模型和 LLM。',
560+
apiVersion: 'API版本',
561+
apiVersionMessage: '请输入API版本!',
560562
},
561563
message: {
562564
registered: '注册成功',
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import { useTranslate } from '@/hooks/common-hooks';
2+
import { IModalProps } from '@/interfaces/common';
3+
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
4+
import { Form, Input, Modal, Select, Switch } from 'antd';
5+
import omit from 'lodash/omit';
6+
7+
type FieldType = IAddLlmRequestBody & {
8+
api_version: string;
9+
vision: boolean;
10+
};
11+
12+
const { Option } = Select;
13+
14+
const AzureOpenAIModal = ({
15+
visible,
16+
hideModal,
17+
onOk,
18+
loading,
19+
llmFactory,
20+
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
21+
const [form] = Form.useForm<FieldType>();
22+
23+
const { t } = useTranslate('setting');
24+
25+
const handleOk = async () => {
26+
const values = await form.validateFields();
27+
const modelType =
28+
values.model_type === 'chat' && values.vision
29+
? 'image2text'
30+
: values.model_type;
31+
32+
const data = {
33+
...omit(values, ['vision']),
34+
model_type: modelType,
35+
llm_factory: llmFactory,
36+
};
37+
console.info(data);
38+
39+
onOk?.(data);
40+
};
41+
const optionsMap = {
42+
Default: [
43+
{ value: 'chat', label: 'chat' },
44+
{ value: 'embedding', label: 'embedding' },
45+
{ value: 'image2text', label: 'image2text' },
46+
],
47+
};
48+
const getOptions = (factory: string) => {
49+
return optionsMap.Default;
50+
};
51+
return (
52+
<Modal
53+
title={t('addLlmTitle', { name: llmFactory })}
54+
open={visible}
55+
onOk={handleOk}
56+
onCancel={hideModal}
57+
okButtonProps={{ loading }}
58+
>
59+
<Form
60+
name="basic"
61+
style={{ maxWidth: 600 }}
62+
autoComplete="off"
63+
layout={'vertical'}
64+
form={form}
65+
>
66+
<Form.Item<FieldType>
67+
label={t('modelType')}
68+
name="model_type"
69+
initialValue={'embedding'}
70+
rules={[{ required: true, message: t('modelTypeMessage') }]}
71+
>
72+
<Select placeholder={t('modelTypeMessage')}>
73+
{getOptions(llmFactory).map((option) => (
74+
<Option key={option.value} value={option.value}>
75+
{option.label}
76+
</Option>
77+
))}
78+
</Select>
79+
</Form.Item>
80+
<Form.Item<FieldType>
81+
label={t('addLlmBaseUrl')}
82+
name="api_base"
83+
rules={[{ required: true, message: t('baseUrlNameMessage') }]}
84+
>
85+
<Input placeholder={t('baseUrlNameMessage')} />
86+
</Form.Item>
87+
<Form.Item<FieldType>
88+
label={t('apiKey')}
89+
name="api_key"
90+
rules={[{ required: false, message: t('apiKeyMessage') }]}
91+
>
92+
<Input placeholder={t('apiKeyMessage')} />
93+
</Form.Item>
94+
<Form.Item<FieldType>
95+
label={t('modelName')}
96+
name="llm_name"
97+
initialValue="gpt-3.5-turbo"
98+
rules={[{ required: true, message: t('modelNameMessage') }]}
99+
>
100+
<Input placeholder={t('modelNameMessage')} />
101+
</Form.Item>
102+
<Form.Item<FieldType>
103+
label={t('apiVersion')}
104+
name="api_version"
105+
initialValue="2024-02-01"
106+
rules={[{ required: false, message: t('apiVersionMessage') }]}
107+
>
108+
<Input placeholder={t('apiVersionMessage')} />
109+
</Form.Item>
110+
<Form.Item noStyle dependencies={['model_type']}>
111+
{({ getFieldValue }) =>
112+
getFieldValue('model_type') === 'chat' && (
113+
<Form.Item
114+
label={t('vision')}
115+
valuePropName="checked"
116+
name={'vision'}
117+
>
118+
<Switch />
119+
</Form.Item>
120+
)
121+
}
122+
</Form.Item>
123+
</Form>
124+
</Modal>
125+
);
126+
};
127+
128+
export default AzureOpenAIModal;

web/src/pages/user-setting/setting-model/hooks.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,33 @@ export const useSubmitBedrock = () => {
353353
};
354354
};
355355

356+
export const useSubmitAzure = () => {
357+
const { addLlm, loading } = useAddLlm();
358+
const {
359+
visible: AzureAddingVisible,
360+
hideModal: hideAzureAddingModal,
361+
showModal: showAzureAddingModal,
362+
} = useSetModalState();
363+
364+
const onAzureAddingOk = useCallback(
365+
async (payload: IAddLlmRequestBody) => {
366+
const ret = await addLlm(payload);
367+
if (ret === 0) {
368+
hideAzureAddingModal();
369+
}
370+
},
371+
[hideAzureAddingModal, addLlm],
372+
);
373+
374+
return {
375+
AzureAddingLoading: loading,
376+
onAzureAddingOk,
377+
AzureAddingVisible,
378+
hideAzureAddingModal,
379+
showAzureAddingModal,
380+
};
381+
};
382+
356383
export const useHandleDeleteLlm = (llmFactory: string) => {
357384
const { deleteLlm } = useDeleteLlm();
358385
const showDeleteConfirm = useShowDeleteConfirm();

web/src/pages/user-setting/setting-model/index.tsx

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import SettingTitle from '../components/setting-title';
2929
import { isLocalLlmFactory } from '../utils';
3030
import TencentCloudModal from './Tencent-modal';
3131
import ApiKeyModal from './api-key-modal';
32+
import AzureOpenAIModal from './azure-openai-modal';
3233
import BedrockModal from './bedrock-modal';
3334
import { IconMap } from './constant';
3435
import FishAudioModal from './fish-audio-modal';
@@ -37,6 +38,7 @@ import {
3738
useHandleDeleteFactory,
3839
useHandleDeleteLlm,
3940
useSubmitApiKey,
41+
useSubmitAzure,
4042
useSubmitBedrock,
4143
useSubmitFishAudio,
4244
useSubmitGoogle,
@@ -109,7 +111,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
109111
item.name === 'BaiduYiyan' ||
110112
item.name === 'Fish Audio' ||
111113
item.name === 'Tencent Cloud' ||
112-
item.name === 'Google Cloud'
114+
item.name === 'Google Cloud' ||
115+
item.name === 'Azure OpenAI'
113116
? t('addTheModel')
114117
: 'API-Key'}
115118
<SettingOutlined />
@@ -242,6 +245,14 @@ const UserSettingModel = () => {
242245
showBedrockAddingModal,
243246
} = useSubmitBedrock();
244247

248+
const {
249+
AzureAddingVisible,
250+
hideAzureAddingModal,
251+
showAzureAddingModal,
252+
onAzureAddingOk,
253+
AzureAddingLoading,
254+
} = useSubmitAzure();
255+
245256
const ModalMap = useMemo(
246257
() => ({
247258
Bedrock: showBedrockAddingModal,
@@ -252,6 +263,7 @@ const UserSettingModel = () => {
252263
'Fish Audio': showFishAudioAddingModal,
253264
'Tencent Cloud': showTencentCloudAddingModal,
254265
'Google Cloud': showGoogleAddingModal,
266+
'Azure-OpenAI': showAzureAddingModal,
255267
}),
256268
[
257269
showBedrockAddingModal,
@@ -262,6 +274,7 @@ const UserSettingModel = () => {
262274
showyiyanAddingModal,
263275
showFishAudioAddingModal,
264276
showGoogleAddingModal,
277+
showAzureAddingModal,
265278
],
266279
);
267280

@@ -435,6 +448,13 @@ const UserSettingModel = () => {
435448
loading={bedrockAddingLoading}
436449
llmFactory={'Bedrock'}
437450
></BedrockModal>
451+
<AzureOpenAIModal
452+
visible={AzureAddingVisible}
453+
hideModal={hideAzureAddingModal}
454+
onOk={onAzureAddingOk}
455+
loading={AzureAddingLoading}
456+
llmFactory={'Azure-OpenAI'}
457+
></AzureOpenAIModal>
438458
</section>
439459
);
440460
};

web/src/pages/user-setting/setting-model/ollama-modal/index.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ const OllamaModal = ({
101101
<Form.Item<FieldType>
102102
label={t('modelType')}
103103
name="model_type"
104-
initialValue={'chat'}
104+
initialValue={'embedding'}
105105
rules={[{ required: true, message: t('modelTypeMessage') }]}
106106
>
107107
<Select placeholder={t('modelTypeMessage')}>

0 commit comments

Comments
 (0)