1
1
from django .conf import settings
2
-
3
2
import logging
4
3
import os
4
+
5
5
logger = logging .getLogger (__name__ )
6
6
7
7
MODEL_3 = "gpt-3.5-turbo"
8
8
MODEL_4 = "gpt-4"
9
9
MODEL = MODEL_4
10
10
11
- AI_RAG_USE = "AI_RAG_USE"
12
- AI_REWRITE_USE = "AI_REWRITE_USE"
11
+ AI_RAG_USE = "AI_RAG_USE"
12
+ AI_REWRITE_USE = "AI_REWRITE_USE"
13
13
AI_QUERY_USE = "AI_QUERY_USE"
14
14
15
15
class OpenAIClient :
16
16
"""
17
17
Encapsulates the logic for initializing different types of AI clients,
18
18
abstracts the complexity of client creation, and provides clear access to the
19
- client.
20
- throws ValueError when no key is configured or passed in.
19
+ client. Sets a default base URL for the OpenAI client to localhost:5000,
20
+ but allows overriding with a custom base URL.
21
+ Throws ValueError when no key is configured or passed in.
21
22
"""
22
- def __init__ (self , usage , key = None ):
23
+ def __init__ (self , usage , key = None , base_url = "http://localhost:5000" ):
23
24
if usage not in [AI_RAG_USE , AI_REWRITE_USE , AI_QUERY_USE ]:
24
25
raise NotImplementedError (f"Unknown AI { usage } . Client initialization not supported." )
25
26
26
27
self ._usage = usage
27
- self ._openapi_key = getattr (settings , 'OPENAI_API_KEY' , None )
28
- self ._azure_model = getattr (settings , "AZURE_MODEL" , None )
29
- self ._azureapi_key = getattr (settings , 'AZURE_OPENAI_KEY' , None )
30
- self ._azure_endpoint = getattr (settings , "AZURE_OPENAI_ENDPOINT" , None )
31
- self ._swirl_rw_model = getattr (settings , "SWIRL_REWRITE_MODEL" , None )
32
- self ._swirl_q_model = getattr (settings ,"SWIRL_QUERY_MODEL" , None )
33
- self ._swirl_rag_model = getattr (settings ,'SWIRL_RAG_MODEL' ,None )
28
+ self ._openapi_key = getattr (settings , 'OPENAI_API_KEY' , None )
29
+ self ._azure_model = getattr (settings , "AZURE_MODEL" , None )
30
+ self ._azureapi_key = getattr (settings , 'AZURE_OPENAI_KEY' , None )
31
+ self ._azure_endpoint = getattr (settings , "AZURE_OPENAI_ENDPOINT" , None )
32
+ self ._swirl_rw_model = getattr (settings , "SWIRL_REWRITE_MODEL" , None )
33
+ self ._swirl_q_model = getattr (settings , "SWIRL_QUERY_MODEL" , None )
34
+ self ._swirl_rag_model = getattr (settings , 'SWIRL_RAG_MODEL' , None )
35
+ self ._base_url = base_url # Set the default base URL here
34
36
35
37
logger .debug (f'cons config : { self ._openapi_key } { self ._azure_model } '
36
- f'{ self ._azureapi_key } { self ._azure_endpoint } { self ._swirl_rw_model } { self ._swirl_q_model } { self ._swirl_rag_model } ' )
38
+ f'{ self ._azureapi_key } { self ._azure_endpoint } { self ._swirl_rw_model } { self ._swirl_q_model } { self ._swirl_rag_model } ' )
37
39
38
40
self ._api_key = None
39
41
self ._api_provider = None
@@ -51,22 +53,23 @@ def __init__(self, usage, key=None):
51
53
52
54
def _init_openai_client (self , provider , key ):
53
55
ai_client = None
54
- logger .debug (f'init_openai_client: { provider } { key } ' )
56
+ logger .debug (f'init_openai_client: { provider } { key } { self . _base_url } ' )
55
57
try :
56
58
if provider == "OPENAI" :
57
59
from openai import OpenAI
58
60
ai_client = OpenAI (api_key = key )
61
+ ai_client .base_url = self ._base_url # Apply the base URL here
59
62
elif provider == "AZUREAI" :
60
63
from openai import AzureOpenAI
61
64
ai_client = AzureOpenAI (api_key = key , azure_endpoint = self ._azure_endpoint , api_version = "2023-10-01-preview" )
62
65
else :
63
66
raise NotImplementedError (f"Unknown AI provider { provider } . Client initialization not supported." )
64
67
except Exception as err :
65
- raise err
68
+ raise err
66
69
return ai_client
67
70
68
71
def get_model (self ):
69
- # If the provder is AZURE and the az model is set, use it.
72
+ # If the provider is AZURE and the azure model is set, use it.
70
73
logger .info (f'get model { self ._api_provider } { self ._azure_model } ' )
71
74
if self ._api_provider == 'AZUREAI' and self ._azure_model :
72
75
return self ._azure_model
@@ -80,7 +83,7 @@ def get_model(self):
80
83
return self ._swirl_rag_model
81
84
82
85
def get_encoding_model (self ):
83
- # otherwise use models as per usage
86
+ # Use models as per usage
84
87
if self ._usage == AI_REWRITE_USE :
85
88
return self ._swirl_rw_model
86
89
elif self ._usage == AI_QUERY_USE :
0 commit comments