Skip to content

Commit 9166109

Browse files
DH-5776/fixing the azure openai (#487)
* DH-5776/fixing the azure openai * Fixing the linter * reformat with black
1 parent e0cf408 commit 9166109

File tree

14 files changed

+93
-78
lines changed

14 files changed

+93
-78
lines changed

services/engine/dataherald/api/fastapi.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def async_scanning(scanner, database, table_descriptions, storage):
110110
)
111111

112112

113-
def async_fine_tuning(storage, model):
114-
openai_fine_tuning = OpenAIFineTuning(storage, model)
113+
def async_fine_tuning(system, storage, model):
114+
openai_fine_tuning = OpenAIFineTuning(system, storage, model)
115115
openai_fine_tuning.create_fintuning_dataset()
116116
openai_fine_tuning.create_fine_tuning_job()
117117

@@ -626,7 +626,7 @@ def create_finetuning_job(
626626
e, fine_tuning_request.dict(), "finetuning_not_created"
627627
)
628628

629-
background_tasks.add_task(async_fine_tuning, self.storage, model)
629+
background_tasks.add_task(async_fine_tuning, self.system, self.storage, model)
630630

631631
return model
632632

@@ -652,7 +652,7 @@ def cancel_finetuning_job(
652652
status_code=400, detail="Model has already been cancelled."
653653
)
654654

655-
openai_fine_tuning = OpenAIFineTuning(self.storage, model)
655+
openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model)
656656

657657
return openai_fine_tuning.cancel_finetuning_job()
658658

@@ -665,7 +665,7 @@ def get_finetunings(self, db_connection_id: str | None = None) -> list[Finetunin
665665
models = model_repository.find_by(query)
666666
result = []
667667
for model in models:
668-
openai_fine_tuning = OpenAIFineTuning(self.storage, model)
668+
openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model)
669669
result.append(
670670
Finetuning(**openai_fine_tuning.retrieve_finetuning_job().dict())
671671
)
@@ -685,7 +685,7 @@ def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
685685
model = model_repository.find_by_id(finetuning_job_id)
686686
if not model:
687687
raise HTTPException(status_code=404, detail="Model not found")
688-
openai_fine_tuning = OpenAIFineTuning(self.storage, model)
688+
openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model)
689689
return openai_fine_tuning.retrieve_finetuning_job()
690690

691691
@override

services/engine/dataherald/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class Settings(BaseSettings):
4545
encrypt_key: str = os.environ.get("ENCRYPT_KEY")
4646
s3_aws_access_key_id: str | None = os.environ.get("S3_AWS_ACCESS_KEY_ID")
4747
s3_aws_secret_access_key: str | None = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
48-
#Needed for Azure OpenAI integration:
48+
# Needed for Azure OpenAI integration:
4949
azure_api_key: str | None = os.environ.get("AZURE_API_KEY")
5050
embedding_model: str | None = os.environ.get("EMBEDDING_MODEL")
5151
azure_api_version: str | None = os.environ.get("AZURE_API_VERSION")

services/engine/dataherald/finetuning/openai_finetuning.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88
import numpy as np
99
import tiktoken
10-
from langchain_openai import OpenAIEmbeddings
10+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
1111
from openai import OpenAI
1212
from overrides import override
1313
from sql_metadata import Parser
1414
from tiktoken import Encoding
1515

16+
from dataherald.config import System
1617
from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus
1718
from dataherald.db_scanner.repository.base import TableDescriptionRepository
1819
from dataherald.finetuning import FinetuningModel
@@ -36,17 +37,24 @@ class OpenAIFineTuning(FinetuningModel):
3637
storage: Any
3738
client: OpenAI
3839

39-
def __init__(self, storage: Any, fine_tuning_model: Finetuning):
40+
def __init__(self, system: System, storage: Any, fine_tuning_model: Finetuning):
4041
self.storage = storage
42+
self.system = system
4143
self.fine_tuning_model = fine_tuning_model
4244
db_connection_repository = DatabaseConnectionRepository(storage)
4345
db_connection = db_connection_repository.find_by_id(
4446
fine_tuning_model.db_connection_id
4547
)
46-
self.embedding = OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
47-
openai_api_key=db_connection.decrypt_api_key(),
48-
model=EMBEDDING_MODEL,
49-
)
48+
if self.system.settings["azure_api_key"] is not None:
49+
self.embedding = AzureOpenAIEmbeddings(
50+
azure_api_key=db_connection.decrypt_api_key(),
51+
model=EMBEDDING_MODEL,
52+
)
53+
else:
54+
self.embedding = OpenAIEmbeddings(
55+
openai_api_key=db_connection.decrypt_api_key(),
56+
model=EMBEDDING_MODEL,
57+
)
5058
self.encoding = tiktoken.encoding_for_model(
5159
fine_tuning_model.base_llm.model_name
5260
)

services/engine/dataherald/model/base_model.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Any
33

4-
from langchain.llms import AlephAlpha, Anthropic, Cohere, OpenAI
4+
from langchain.llms import AlephAlpha, Anthropic, AzureOpenAI, Cohere, OpenAI
55
from overrides import override
66

77
from dataherald.model import LLMModel
@@ -19,16 +19,16 @@ def __init__(self, system):
1919
self.azure_api_key = os.environ.get("AZURE_API_KEY")
2020

2121
@override
22-
def get_model(
22+
def get_model( # noqa: C901
2323
self,
2424
database_connection: DatabaseConnection,
2525
model_family="openai",
2626
model_name="davinci-003",
2727
api_base: str | None = None, # noqa: ARG002
2828
**kwargs: Any
2929
) -> Any:
30-
if self.system.settings['azure_api_key'] != None:
31-
model_family = 'azure'
30+
if self.system.settings["azure_api_key"] is not None:
31+
model_family = "azure"
3232
if database_connection.llm_api_key is not None:
3333
fernet_encrypt = FernetEncrypt()
3434
api_key = fernet_encrypt.decrypt(database_connection.llm_api_key)
@@ -39,7 +39,7 @@ def get_model(
3939
elif model_family == "google":
4040
self.google_api_key = api_key
4141
elif model_family == "azure":
42-
self.azure_api_key == api_key
42+
self.azure_api_key = api_key
4343
if self.openai_api_key:
4444
self.model = OpenAI(model_name=model_name, **kwargs)
4545
elif self.aleph_alpha_api_key:

services/engine/dataherald/model/chat_model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any
22

33
from langchain_community.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm
4-
from langchain_openai import ChatOpenAI, AzureChatOpenAI
4+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
55
from overrides import override
66

77
from dataherald.model import LLMModel
@@ -22,16 +22,16 @@ def get_model(
2222
**kwargs: Any
2323
) -> Any:
2424
api_key = database_connection.decrypt_api_key()
25-
if self.system.settings['azure_api_key'] != None:
26-
model_family = 'azure'
25+
if self.system.settings["azure_api_key"] is not None:
26+
model_family = "azure"
2727
if model_family == "azure":
28-
if api_base.endswith("/"): #TODO check where final "/" is added to api_base
28+
if api_base.endswith("/"): # check where final "/" is added to api_base
2929
api_base = api_base[:-1]
3030
return AzureChatOpenAI(
3131
deployment_name=model_name,
3232
openai_api_key=api_key,
33-
azure_endpoint= api_base,
34-
api_version=self.system.settings['azure_api_version'],
33+
azure_endpoint=api_base,
34+
api_version=self.system.settings["azure_api_version"],
3535
**kwargs
3636
)
3737
if model_family == "openai":

services/engine/dataherald/services/sql_generations.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def update_the_initial_sql_generation(
6363
initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps
6464
return self.sql_generation_repository.update(initial_sql_generation)
6565

66-
def create(
66+
def create( # noqa: PLR0912
6767
self, prompt_id: str, sql_generation_request: SQLGenerationRequest
68-
) -> SQLGeneration:
68+
) -> SQLGeneration: # noqa: PLR0912
6969
initial_sql_generation = SQLGeneration(
7070
prompt_id=prompt_id,
7171
created_at=datetime.now(),

services/engine/dataherald/sql_generator/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,15 @@ def generate_response(
179179
"""Generates a response to a user question."""
180180
pass
181181

182-
def stream_agent_steps( # noqa: C901
182+
def stream_agent_steps( # noqa: PLR0912, C901
183183
self,
184184
question: str,
185185
agent_executor: AgentExecutor,
186186
response: SQLGeneration,
187187
sql_generation_repository: SQLGenerationRepository,
188188
queue: Queue,
189189
metadata: dict = None,
190-
):
190+
): # noqa: PLR0912
191191
try:
192192
with get_openai_callback() as cb:
193193
for chunk in agent_executor.stream(

services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from langchain.chains.llm import LLMChain
2222
from langchain.tools.base import BaseTool
2323
from langchain_community.callbacks import get_openai_callback
24-
from langchain_openai import OpenAIEmbeddings
24+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
2525
from openai import OpenAI
2626
from overrides import override
2727
from pydantic import BaseModel, Field
@@ -587,14 +587,24 @@ def generate_response(
587587
)
588588
finetunings_repository = FinetuningsRepository(storage)
589589
finetuning = finetunings_repository.find_by_id(self.finetuning_id)
590-
openai_fine_tuning = OpenAIFineTuning(storage, finetuning)
590+
openai_fine_tuning = OpenAIFineTuning(self.system, storage, finetuning)
591591
finetuning = openai_fine_tuning.retrieve_finetuning_job()
592592
if finetuning.status != FineTuningStatus.SUCCEEDED.value:
593593
raise FinetuningNotAvailableError(
594594
f"Finetuning({self.finetuning_id}) has the status {finetuning.status}."
595595
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
596596
)
597597
self.database = SQLDatabase.get_sql_engine(database_connection)
598+
if self.system.settings["azure_api_key"] is not None:
599+
embedding = AzureOpenAIEmbeddings(
600+
openai_api_key=database_connection.decrypt_api_key(),
601+
model=EMBEDDING_MODEL,
602+
)
603+
else:
604+
embedding = OpenAIEmbeddings(
605+
openai_api_key=database_connection.decrypt_api_key(),
606+
model=EMBEDDING_MODEL,
607+
)
598608
toolkit = SQLDatabaseToolkit(
599609
db=self.database,
600610
instructions=instructions,
@@ -605,10 +615,7 @@ def generate_response(
605615
use_finetuned_model_only=self.use_fintuned_model_only,
606616
model_name=finetuning.base_llm.model_name,
607617
openai_fine_tuning=openai_fine_tuning,
608-
embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
609-
openai_api_key=database_connection.decrypt_api_key(),
610-
model=EMBEDDING_MODEL,
611-
),
618+
embedding=embedding,
612619
)
613620
agent_executor = self.create_sql_agent(
614621
toolkit=toolkit,
@@ -693,14 +700,24 @@ def stream_response(
693700
)
694701
finetunings_repository = FinetuningsRepository(storage)
695702
finetuning = finetunings_repository.find_by_id(self.finetuning_id)
696-
openai_fine_tuning = OpenAIFineTuning(storage, finetuning)
703+
openai_fine_tuning = OpenAIFineTuning(self.system, storage, finetuning)
697704
finetuning = openai_fine_tuning.retrieve_finetuning_job()
698705
if finetuning.status != FineTuningStatus.SUCCEEDED.value:
699706
raise FinetuningNotAvailableError(
700707
f"Finetuning({self.finetuning_id}) has the status {finetuning.status}."
701708
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
702709
)
703710
self.database = SQLDatabase.get_sql_engine(database_connection)
711+
if self.system.settings["azure_api_key"] is not None:
712+
embedding = AzureOpenAIEmbeddings(
713+
openai_api_key=database_connection.decrypt_api_key(),
714+
model=EMBEDDING_MODEL,
715+
)
716+
else:
717+
embedding = OpenAIEmbeddings(
718+
openai_api_key=database_connection.decrypt_api_key(),
719+
model=EMBEDDING_MODEL,
720+
)
704721
toolkit = SQLDatabaseToolkit(
705722
db=self.database,
706723
instructions=instructions,
@@ -710,10 +727,7 @@ def stream_response(
710727
use_finetuned_model_only=self.use_fintuned_model_only,
711728
model_name=finetuning.base_llm.model_name,
712729
openai_fine_tuning=openai_fine_tuning,
713-
embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
714-
openai_api_key=database_connection.decrypt_api_key(),
715-
model=EMBEDDING_MODEL,
716-
),
730+
embedding=embedding,
717731
)
718732
agent_executor = self.create_sql_agent(
719733
toolkit=toolkit,

services/engine/dataherald/sql_generator/dataherald_sqlagent.py

+16-23
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from langchain.chains.llm import LLMChain
2323
from langchain.tools.base import BaseTool
2424
from langchain_community.callbacks import get_openai_callback
25-
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
25+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
2626
from overrides import override
2727
from pydantic import BaseModel, Field
2828
from sql_metadata import Parser
@@ -710,13 +710,13 @@ def create_sql_agent(
710710
)
711711

712712
@override
713-
def generate_response(
713+
def generate_response( # noqa: PLR0912
714714
self,
715715
user_prompt: Prompt,
716716
database_connection: DatabaseConnection,
717717
context: List[dict] = None,
718718
metadata: dict = None,
719-
) -> SQLGeneration:
719+
) -> SQLGeneration: # noqa: PLR0912
720720
context_store = self.system.instance(ContextStore)
721721
storage = self.system.instance(DB)
722722
response = SQLGeneration(
@@ -753,8 +753,8 @@ def generate_response(
753753
number_of_samples = 0
754754
logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
755755
self.database = SQLDatabase.get_sql_engine(database_connection)
756-
#Set Embeddings class depending on azure / not azure
757-
if self.llm.openai_api_type == "azure":
756+
# Set Embeddings class depending on azure / not azure
757+
if self.system.settings["azure_api_key"] is not None:
758758
toolkit = SQLDatabaseToolkit(
759759
db=self.database,
760760
context=context,
@@ -873,21 +873,17 @@ def stream_response(
873873
new_fewshot_examples = None
874874
number_of_samples = 0
875875
self.database = SQLDatabase.get_sql_engine(database_connection)
876-
#Set Embeddings class depending on azure / not azure
877-
if self.llm.openai_api_type == "azure":
878-
toolkit = SQLDatabaseToolkit(
879-
db=self.database,
880-
context=context,
881-
few_shot_examples=new_fewshot_examples,
882-
instructions=instructions,
883-
is_multiple_schema=True if user_prompt.schemas else False,
884-
db_scan=db_scan,
885-
embedding=AzureOpenAIEmbeddings(
886-
openai_api_key=database_connection.decrypt_api_key(),
887-
model=EMBEDDING_MODEL,
888-
),
876+
# Set Embeddings class depending on azure / not azure
877+
if self.system.settings["azure_api_key"] is not None:
878+
embedding = AzureOpenAIEmbeddings(
879+
openai_api_key=database_connection.decrypt_api_key(),
880+
model=EMBEDDING_MODEL,
881+
)
882+
else:
883+
embedding = OpenAIEmbeddings(
884+
openai_api_key=database_connection.decrypt_api_key(),
885+
model=EMBEDDING_MODEL,
889886
)
890-
else:
891887
toolkit = SQLDatabaseToolkit(
892888
queuer=queue,
893889
db=self.database,
@@ -896,10 +892,7 @@ def stream_response(
896892
instructions=instructions,
897893
is_multiple_schema=True if user_prompt.schemas else False,
898894
db_scan=db_scan,
899-
embedding=OpenAIEmbeddings(
900-
openai_api_key=database_connection.decrypt_api_key(),
901-
model=EMBEDDING_MODEL,
902-
),
895+
embedding=embedding,
903896
)
904897
agent_executor = self.create_sql_agent(
905898
toolkit=toolkit,

services/enterprise/exceptions/exception_handlers.py

-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414

1515
async def exception_handler(request: Request, exc: BaseError): # noqa: ARG001
16-
1716
trace_id = exc.trace_id
1817
error_code = exc.error_code
1918
status_code = exc.status_code

services/enterprise/exceptions/exceptions.py

-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def __init__(
3939
description: str = None,
4040
detail: dict = None,
4141
) -> None:
42-
4342
if type(self) is BaseError:
4443
raise TypeError("BaseError class may not be instantiated directly")
4544

services/enterprise/modules/db_connection/controller.py

-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ async def ac_get_db_connection(
9595
id: ObjectIdString,
9696
user: User = Security(authenticate_user),
9797
) -> DBConnectionResponse:
98-
9998
return db_connection_service.get_db_connection(id, user.organization_id)
10099

101100

0 commit comments

Comments
 (0)