Skip to content

Commit ac384d8

Browse files
authored
added base_url to openai.py
added ability to change base_url (and set the default as localhost:5000)
1 parent ec4fe97 commit ac384d8

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

swirl/openai/openai.py

+21-18
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,41 @@
11
from django.conf import settings
2-
32
import logging
43
import os
4+
55
logger = logging.getLogger(__name__)
66

77
MODEL_3 = "gpt-3.5-turbo"
88
MODEL_4 = "gpt-4"
99
MODEL = MODEL_4
1010

11-
AI_RAG_USE = "AI_RAG_USE"
12-
AI_REWRITE_USE = "AI_REWRITE_USE"
11+
AI_RAG_USE = "AI_RAG_USE"
12+
AI_REWRITE_USE = "AI_REWRITE_USE"
1313
AI_QUERY_USE = "AI_QUERY_USE"
1414

1515
class OpenAIClient:
1616
"""
1717
Encapsulates the logic for initializing different types of AI clients,
1818
abstracts the complexity of client creation, and provides clear access to the
19-
client.
20-
throws ValueError when no key is configured or passed in.
19+
client. Sets a default base URL for the OpenAI client to localhost:5000,
20+
but allows overriding with a custom base URL.
21+
Throws ValueError when no key is configured or passed in.
2122
"""
22-
def __init__(self, usage, key=None):
23+
def __init__(self, usage, key=None, base_url="http://localhost:5000"):
2324
if usage not in [AI_RAG_USE, AI_REWRITE_USE, AI_QUERY_USE]:
2425
raise NotImplementedError(f"Unknown AI {usage}. Client initialization not supported.")
2526

2627
self._usage = usage
27-
self._openapi_key = getattr(settings, 'OPENAI_API_KEY', None)
28-
self._azure_model = getattr(settings, "AZURE_MODEL", None)
29-
self._azureapi_key = getattr(settings, 'AZURE_OPENAI_KEY', None)
30-
self._azure_endpoint = getattr(settings, "AZURE_OPENAI_ENDPOINT", None)
31-
self._swirl_rw_model = getattr(settings, "SWIRL_REWRITE_MODEL", None)
32-
self._swirl_q_model = getattr(settings,"SWIRL_QUERY_MODEL", None)
33-
self._swirl_rag_model = getattr(settings,'SWIRL_RAG_MODEL',None)
28+
self._openapi_key = getattr(settings, 'OPENAI_API_KEY', None)
29+
self._azure_model = getattr(settings, "AZURE_MODEL", None)
30+
self._azureapi_key = getattr(settings, 'AZURE_OPENAI_KEY', None)
31+
self._azure_endpoint = getattr(settings, "AZURE_OPENAI_ENDPOINT", None)
32+
self._swirl_rw_model = getattr(settings, "SWIRL_REWRITE_MODEL", None)
33+
self._swirl_q_model = getattr(settings, "SWIRL_QUERY_MODEL", None)
34+
self._swirl_rag_model = getattr(settings, 'SWIRL_RAG_MODEL', None)
35+
self._base_url = base_url # Set the default base URL here
3436

3537
logger.debug(f'cons config : {self._openapi_key} {self._azure_model}'
36-
f'{self._azureapi_key} {self._azure_endpoint} {self._swirl_rw_model} {self._swirl_q_model} {self._swirl_rag_model}')
38+
f'{self._azureapi_key} {self._azure_endpoint} {self._swirl_rw_model} {self._swirl_q_model} {self._swirl_rag_model}')
3739

3840
self._api_key = None
3941
self._api_provider = None
@@ -51,22 +53,23 @@ def __init__(self, usage, key=None):
5153

5254
def _init_openai_client(self, provider, key):
5355
ai_client = None
54-
logger.debug(f'init_openai_client: {provider} {key}')
56+
logger.debug(f'init_openai_client: {provider} {key} {self._base_url}')
5557
try:
5658
if provider == "OPENAI":
5759
from openai import OpenAI
5860
ai_client = OpenAI(api_key=key)
61+
ai_client.base_url = self._base_url # Apply the base URL here
5962
elif provider == "AZUREAI":
6063
from openai import AzureOpenAI
6164
ai_client = AzureOpenAI(api_key=key, azure_endpoint=self._azure_endpoint, api_version="2023-10-01-preview")
6265
else:
6366
raise NotImplementedError(f"Unknown AI provider {provider}. Client initialization not supported.")
6467
except Exception as err:
65-
raise err
68+
raise err
6669
return ai_client
6770

6871
def get_model(self):
69-
# If the provder is AZURE and the az model is set, use it.
72+
# If the provider is AZURE and the azure model is set, use it.
7073
logger.info(f'get model {self._api_provider} {self._azure_model}')
7174
if self._api_provider == 'AZUREAI' and self._azure_model:
7275
return self._azure_model
@@ -80,7 +83,7 @@ def get_model(self):
8083
return self._swirl_rag_model
8184

8285
def get_encoding_model(self):
83-
# otherwise use models as per usage
86+
# Use models as per usage
8487
if self._usage == AI_REWRITE_USE:
8588
return self._swirl_rw_model
8689
elif self._usage == AI_QUERY_USE:

0 commit comments

Comments
 (0)