Skip to content

Commit ede31c7

Browse files
DH-5776/add support for Azure embedding
1 parent a66c5f9 commit ede31c7

File tree

6 files changed

+63
-48
lines changed

6 files changed

+63
-48
lines changed

Diff for: dataherald/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class Settings(BaseSettings):
4545
encrypt_key: str = os.environ.get("ENCRYPT_KEY")
4646
s3_aws_access_key_id: str | None = os.environ.get("S3_AWS_ACCESS_KEY_ID")
4747
s3_aws_secret_access_key: str | None = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
48-
#Needed for Azure OpenAI integration:
48+
# Needed for Azure OpenAI integration:
4949
azure_api_key: str | None = os.environ.get("AZURE_API_KEY")
5050
embedding_model: str | None = os.environ.get("EMBEDDING_MODEL")
5151
azure_api_version: str | None = os.environ.get("AZURE_API_VERSION")

Diff for: dataherald/finetuning/openai_finetuning.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88
import numpy as np
99
import tiktoken
10-
from langchain_openai import OpenAIEmbeddings
10+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
1111
from openai import OpenAI
1212
from overrides import override
1313
from sql_metadata import Parser
1414
from tiktoken import Encoding
1515

16+
from dataherald.config import System
1617
from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus
1718
from dataherald.db_scanner.repository.base import TableDescriptionRepository
1819
from dataherald.finetuning import FinetuningModel
@@ -36,17 +37,24 @@ class OpenAIFineTuning(FinetuningModel):
3637
storage: Any
3738
client: OpenAI
3839

39-
def __init__(self, storage: Any, fine_tuning_model: Finetuning):
40+
def __init__(self, system: System, storage: Any, fine_tuning_model: Finetuning):
4041
self.storage = storage
42+
self.system = system
4143
self.fine_tuning_model = fine_tuning_model
4244
db_connection_repository = DatabaseConnectionRepository(storage)
4345
db_connection = db_connection_repository.find_by_id(
4446
fine_tuning_model.db_connection_id
4547
)
46-
self.embedding = OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
47-
openai_api_key=db_connection.decrypt_api_key(),
48-
model=EMBEDDING_MODEL,
49-
)
48+
if self.system.settings["azure_api_key"] is not None:
49+
self.embedding = AzureOpenAIEmbeddings(
50+
azure_api_key=db_connection.decrypt_api_key(),
51+
model=EMBEDDING_MODEL,
52+
)
53+
else:
54+
self.embedding = OpenAIEmbeddings(
55+
openai_api_key=db_connection.decrypt_api_key(),
56+
model=EMBEDDING_MODEL,
57+
)
5058
self.encoding = tiktoken.encoding_for_model(
5159
fine_tuning_model.base_llm.model_name
5260
)

Diff for: dataherald/model/base_model.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from typing import Any
33

4-
from langchain.llms import AlephAlpha, Anthropic, Cohere, OpenAI
4+
from langchain.llms import AlephAlpha, Anthropic, AzureOpenAI, Cohere, OpenAI
55
from overrides import override
66

77
from dataherald.model import LLMModel
@@ -19,16 +19,16 @@ def __init__(self, system):
1919
self.azure_api_key = os.environ.get("AZURE_API_KEY")
2020

2121
@override
22-
def get_model(
22+
def get_model( # noqa: C901
2323
self,
2424
database_connection: DatabaseConnection,
2525
model_family="openai",
2626
model_name="davinci-003",
2727
api_base: str | None = None, # noqa: ARG002
2828
**kwargs: Any
2929
) -> Any:
30-
if self.system.settings['azure_api_key'] != None:
31-
model_family = 'azure'
30+
if self.system.settings["azure_api_key"] is not None:
31+
model_family = "azure"
3232
if database_connection.llm_api_key is not None:
3333
fernet_encrypt = FernetEncrypt()
3434
api_key = fernet_encrypt.decrypt(database_connection.llm_api_key)
@@ -39,7 +39,7 @@ def get_model(
3939
elif model_family == "google":
4040
self.google_api_key = api_key
4141
elif model_family == "azure":
42-
self.azure_api_key == api_key
42+
self.azure_api_key = api_key
4343
if self.openai_api_key:
4444
self.model = OpenAI(model_name=model_name, **kwargs)
4545
elif self.aleph_alpha_api_key:

Diff for: dataherald/model/chat_model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any
22

33
from langchain_community.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm
4-
from langchain_openai import ChatOpenAI, AzureChatOpenAI
4+
from langchain_openai import AzureChatOpenAI, ChatOpenAI
55
from overrides import override
66

77
from dataherald.model import LLMModel
@@ -22,16 +22,16 @@ def get_model(
2222
**kwargs: Any
2323
) -> Any:
2424
api_key = database_connection.decrypt_api_key()
25-
if self.system.settings['azure_api_key'] != None:
26-
model_family = 'azure'
25+
if self.system.settings["azure_api_key"] is not None:
26+
model_family = "azure"
2727
if model_family == "azure":
28-
if api_base.endswith("/"): #TODO check where final "/" is added to api_base
28+
if api_base.endswith("/"): # check where final "/" is added to api_base
2929
api_base = api_base[:-1]
3030
return AzureChatOpenAI(
3131
deployment_name=model_name,
3232
openai_api_key=api_key,
33-
azure_endpoint= api_base,
34-
api_version=self.system.settings['azure_api_version'],
33+
azure_endpoint=api_base,
34+
api_version=self.system.settings["azure_api_version"],
3535
**kwargs
3636
)
3737
if model_family == "openai":

Diff for: dataherald/sql_generator/dataherald_finetuning_agent.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from langchain.chains.llm import LLMChain
2222
from langchain.tools.base import BaseTool
2323
from langchain_community.callbacks import get_openai_callback
24-
from langchain_openai import OpenAIEmbeddings
24+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
2525
from openai import OpenAI
2626
from overrides import override
2727
from pydantic import BaseModel, Field
@@ -587,14 +587,24 @@ def generate_response(
587587
)
588588
finetunings_repository = FinetuningsRepository(storage)
589589
finetuning = finetunings_repository.find_by_id(self.finetuning_id)
590-
openai_fine_tuning = OpenAIFineTuning(storage, finetuning)
590+
openai_fine_tuning = OpenAIFineTuning(self.system, storage, finetuning)
591591
finetuning = openai_fine_tuning.retrieve_finetuning_job()
592592
if finetuning.status != FineTuningStatus.SUCCEEDED.value:
593593
raise FinetuningNotAvailableError(
594594
f"Finetuning({self.finetuning_id}) has the status {finetuning.status}."
595595
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
596596
)
597597
self.database = SQLDatabase.get_sql_engine(database_connection)
598+
if self.llm.openai_api_type == "azure":
599+
embedding = AzureOpenAIEmbeddings(
600+
openai_api_key=database_connection.decrypt_api_key(),
601+
model=EMBEDDING_MODEL,
602+
)
603+
else:
604+
embedding = OpenAIEmbeddings(
605+
openai_api_key=database_connection.decrypt_api_key(),
606+
model=EMBEDDING_MODEL,
607+
)
598608
toolkit = SQLDatabaseToolkit(
599609
db=self.database,
600610
instructions=instructions,
@@ -605,10 +615,7 @@ def generate_response(
605615
use_finetuned_model_only=self.use_fintuned_model_only,
606616
model_name=finetuning.base_llm.model_name,
607617
openai_fine_tuning=openai_fine_tuning,
608-
embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
609-
openai_api_key=database_connection.decrypt_api_key(),
610-
model=EMBEDDING_MODEL,
611-
),
618+
embedding=embedding,
612619
)
613620
agent_executor = self.create_sql_agent(
614621
toolkit=toolkit,
@@ -693,14 +700,24 @@ def stream_response(
693700
)
694701
finetunings_repository = FinetuningsRepository(storage)
695702
finetuning = finetunings_repository.find_by_id(self.finetuning_id)
696-
openai_fine_tuning = OpenAIFineTuning(storage, finetuning)
703+
openai_fine_tuning = OpenAIFineTuning(self.system, storage, finetuning)
697704
finetuning = openai_fine_tuning.retrieve_finetuning_job()
698705
if finetuning.status != FineTuningStatus.SUCCEEDED.value:
699706
raise FinetuningNotAvailableError(
700707
f"Finetuning({self.finetuning_id}) has the status {finetuning.status}."
701708
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
702709
)
703710
self.database = SQLDatabase.get_sql_engine(database_connection)
711+
if self.llm.openai_api_type == "azure":
712+
embedding = AzureOpenAIEmbeddings(
713+
openai_api_key=database_connection.decrypt_api_key(),
714+
model=EMBEDDING_MODEL,
715+
)
716+
else:
717+
embedding = OpenAIEmbeddings(
718+
openai_api_key=database_connection.decrypt_api_key(),
719+
model=EMBEDDING_MODEL,
720+
)
704721
toolkit = SQLDatabaseToolkit(
705722
db=self.database,
706723
instructions=instructions,
@@ -710,10 +727,7 @@ def stream_response(
710727
use_finetuned_model_only=self.use_fintuned_model_only,
711728
model_name=finetuning.base_llm.model_name,
712729
openai_fine_tuning=openai_fine_tuning,
713-
embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
714-
openai_api_key=database_connection.decrypt_api_key(),
715-
model=EMBEDDING_MODEL,
716-
),
730+
embedding=embedding,
717731
)
718732
agent_executor = self.create_sql_agent(
719733
toolkit=toolkit,

Diff for: dataherald/sql_generator/dataherald_sqlagent.py

+12-19
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from langchain.chains.llm import LLMChain
2323
from langchain.tools.base import BaseTool
2424
from langchain_community.callbacks import get_openai_callback
25-
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
25+
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
2626
from overrides import override
2727
from pydantic import BaseModel, Field
2828
from sql_metadata import Parser
@@ -753,7 +753,7 @@ def generate_response(
753753
number_of_samples = 0
754754
logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
755755
self.database = SQLDatabase.get_sql_engine(database_connection)
756-
#Set Embeddings class depending on azure / not azure
756+
# Set Embeddings class depending on azure / not azure
757757
if self.llm.openai_api_type == "azure":
758758
toolkit = SQLDatabaseToolkit(
759759
db=self.database,
@@ -873,21 +873,17 @@ def stream_response(
873873
new_fewshot_examples = None
874874
number_of_samples = 0
875875
self.database = SQLDatabase.get_sql_engine(database_connection)
876-
#Set Embeddings class depending on azure / not azure
876+
# Set Embeddings class depending on azure / not azure
877877
if self.llm.openai_api_type == "azure":
878-
toolkit = SQLDatabaseToolkit(
879-
db=self.database,
880-
context=context,
881-
few_shot_examples=new_fewshot_examples,
882-
instructions=instructions,
883-
is_multiple_schema=True if user_prompt.schemas else False,
884-
db_scan=db_scan,
885-
embedding=AzureOpenAIEmbeddings(
886-
openai_api_key=database_connection.decrypt_api_key(),
887-
model=EMBEDDING_MODEL,
888-
),
878+
embedding = AzureOpenAIEmbeddings(
879+
openai_api_key=database_connection.decrypt_api_key(),
880+
model=EMBEDDING_MODEL,
881+
)
882+
else:
883+
embedding = OpenAIEmbeddings(
884+
openai_api_key=database_connection.decrypt_api_key(),
885+
model=EMBEDDING_MODEL,
889886
)
890-
else:
891887
toolkit = SQLDatabaseToolkit(
892888
queuer=queue,
893889
db=self.database,
@@ -896,10 +892,7 @@ def stream_response(
896892
instructions=instructions,
897893
is_multiple_schema=True if user_prompt.schemas else False,
898894
db_scan=db_scan,
899-
embedding=OpenAIEmbeddings(
900-
openai_api_key=database_connection.decrypt_api_key(),
901-
model=EMBEDDING_MODEL,
902-
),
895+
embedding=embedding,
903896
)
904897
agent_executor = self.create_sql_agent(
905898
toolkit=toolkit,

0 commit comments

Comments
 (0)