Skip to content

Commit 581b6ff

Browse files
Azure OpenAI deployments compatibility (#457)
(non-fine-tuning) Co-authored-by: Julio Navarro <jmnavarro@ferrovial.com>
1 parent 514e498 commit 581b6ff

File tree

9 files changed

+100
-32
lines changed

9 files changed

+100
-32
lines changed

Diff for: README.md

+15
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,21 @@ UPPER_LIMIT_QUERY_RETURN_ROWS = 50
8383
DH_ENGINE_TIMEOUT = 150
8484
```
8585

86+
To use models deployed in Azure OpenAI, set the following variables:
87+
```
88+
AZURE_API_KEY = "xxxxx"
89+
AZURE_OPENAI_API_KEY = "xxxxxx"
90+
API_BASE = "azure_openai_endpoint"
91+
AZURE_OPENAI_ENDPOINT = "azure_openai_endpoint"
92+
AZURE_API_VERSION = "version of the API to use"
93+
LLM_MODEL = "name_of_the_deployment"
94+
```
95+
In addition, an embedding model is also used: a deployment named "text-embedding-3-large" must exist.
96+
97+
If the AZURE_API_KEY environment variable is set, Azure models are used.
98+
99+
Remember to remove comments beside the environment variables.
100+
86101
While not strictly required, we also strongly suggest you change the MONGO username and password fields as well.
87102

88103
Follow the next commands to generate an ENCRYPT_KEY and paste it in the .env file like

Diff for: dataherald/config.py

+4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ class Settings(BaseSettings):
4545
encrypt_key: str = os.environ.get("ENCRYPT_KEY")
4646
s3_aws_access_key_id: str | None = os.environ.get("S3_AWS_ACCESS_KEY_ID")
4747
s3_aws_secret_access_key: str | None = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
48+
#Needed for Azure OpenAI integration:
49+
azure_api_key: str | None = os.environ.get("AZURE_API_KEY")
50+
embedding_model: str | None = os.environ.get("EMBEDDING_MODEL")
51+
azure_api_version: str | None = os.environ.get("AZURE_API_VERSION")
4852
only_store_csv_files_locally: bool | None = os.environ.get(
4953
"ONLY_STORE_CSV_FILES_LOCALLY", False
5054
)

Diff for: dataherald/finetuning/openai_finetuning.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, storage: Any, fine_tuning_model: Finetuning):
4343
db_connection = db_connection_repository.find_by_id(
4444
fine_tuning_model.db_connection_id
4545
)
46-
self.embedding = OpenAIEmbeddings(
46+
self.embedding = OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
4747
openai_api_key=db_connection.decrypt_api_key(),
4848
model=EMBEDDING_MODEL,
4949
)

Diff for: dataherald/model/base_model.py

+7
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self, system):
1616
self.aleph_alpha_api_key = os.environ.get("ALEPH_ALPHA_API_KEY")
1717
self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
1818
self.cohere_api_key = os.environ.get("COHERE_API_KEY")
19+
self.azure_api_key = os.environ.get("AZURE_API_KEY")
1920

2021
@override
2122
def get_model(
@@ -26,6 +27,8 @@ def get_model(
2627
api_base: str | None = None, # noqa: ARG002
2728
**kwargs: Any
2829
) -> Any:
30+
if self.system.settings['azure_api_key'] != None:
31+
model_family = 'azure'
2932
if database_connection.llm_api_key is not None:
3033
fernet_encrypt = FernetEncrypt()
3134
api_key = fernet_encrypt.decrypt(database_connection.llm_api_key)
@@ -35,6 +38,8 @@ def get_model(
3538
self.anthropic_api_key = api_key
3639
elif model_family == "google":
3740
self.google_api_key = api_key
41+
elif model_family == "azure":
42+
self.azure_api_key == api_key
3843
if self.openai_api_key:
3944
self.model = OpenAI(model_name=model_name, **kwargs)
4045
elif self.aleph_alpha_api_key:
@@ -43,6 +48,8 @@ def get_model(
4348
self.model = Anthropic(model=model_name, **kwargs)
4449
elif self.cohere_api_key:
4550
self.model = Cohere(model=model_name, **kwargs)
51+
elif self.azure_api_key:
52+
self.model = AzureOpenAI(model=model_name, **kwargs)
4653
else:
4754
raise ValueError("No valid API key environment variable found")
4855
return self.model

Diff for: dataherald/model/chat_model.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any
22

33
from langchain_community.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm
4-
from langchain_openai import ChatOpenAI
4+
from langchain_openai import ChatOpenAI, AzureChatOpenAI
55
from overrides import override
66

77
from dataherald.model import LLMModel
@@ -22,6 +22,18 @@ def get_model(
2222
**kwargs: Any
2323
) -> Any:
2424
api_key = database_connection.decrypt_api_key()
25+
if self.system.settings['azure_api_key'] != None:
26+
model_family = 'azure'
27+
if model_family == "azure":
28+
if api_base.endswith("/"): #TODO check where final "/" is added to api_base
29+
api_base = api_base[:-1]
30+
return AzureChatOpenAI(
31+
deployment_name=model_name,
32+
openai_api_key=api_key,
33+
azure_endpoint= api_base,
34+
api_version=self.system.settings['azure_api_version'],
35+
**kwargs
36+
)
2537
if model_family == "openai":
2638
return ChatOpenAI(
2739
model_name=model_name,

Diff for: dataherald/sql_generator/dataherald_finetuning_agent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def generate_response(
605605
use_finetuned_model_only=self.use_fintuned_model_only,
606606
model_name=finetuning.base_llm.model_name,
607607
openai_fine_tuning=openai_fine_tuning,
608-
embedding=OpenAIEmbeddings(
608+
embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
609609
openai_api_key=database_connection.decrypt_api_key(),
610610
model=EMBEDDING_MODEL,
611611
),
@@ -710,7 +710,7 @@ def stream_response(
710710
use_finetuned_model_only=self.use_fintuned_model_only,
711711
model_name=finetuning.base_llm.model_name,
712712
openai_fine_tuning=openai_fine_tuning,
713-
embedding=OpenAIEmbeddings(
713+
embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
714714
openai_api_key=database_connection.decrypt_api_key(),
715715
model=EMBEDDING_MODEL,
716716
),

Diff for: dataherald/sql_generator/dataherald_sqlagent.py

+56-26
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from langchain.chains.llm import LLMChain
2323
from langchain.tools.base import BaseTool
2424
from langchain_community.callbacks import get_openai_callback
25-
from langchain_openai import OpenAIEmbeddings
25+
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
2626
from overrides import override
2727
from pydantic import BaseModel, Field
2828
from sql_metadata import Parser
@@ -753,18 +753,33 @@ def generate_response(
753753
number_of_samples = 0
754754
logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
755755
self.database = SQLDatabase.get_sql_engine(database_connection)
756-
toolkit = SQLDatabaseToolkit(
757-
db=self.database,
758-
context=context,
759-
few_shot_examples=new_fewshot_examples,
760-
instructions=instructions,
761-
is_multiple_schema=True if user_prompt.schemas else False,
762-
db_scan=db_scan,
763-
embedding=OpenAIEmbeddings(
764-
openai_api_key=database_connection.decrypt_api_key(),
765-
model=EMBEDDING_MODEL,
766-
),
767-
)
756+
#Set Embeddings class depending on azure / not azure
757+
if self.llm.openai_api_type == "azure":
758+
toolkit = SQLDatabaseToolkit(
759+
db=self.database,
760+
context=context,
761+
few_shot_examples=new_fewshot_examples,
762+
instructions=instructions,
763+
is_multiple_schema=True if user_prompt.schemas else False,
764+
db_scan=db_scan,
765+
embedding=AzureOpenAIEmbeddings(
766+
openai_api_key=database_connection.decrypt_api_key(),
767+
model=EMBEDDING_MODEL,
768+
),
769+
)
770+
else:
771+
toolkit = SQLDatabaseToolkit(
772+
db=self.database,
773+
context=context,
774+
few_shot_examples=new_fewshot_examples,
775+
instructions=instructions,
776+
is_multiple_schema=True if user_prompt.schemas else False,
777+
db_scan=db_scan,
778+
embedding=OpenAIEmbeddings(
779+
openai_api_key=database_connection.decrypt_api_key(),
780+
model=EMBEDDING_MODEL,
781+
),
782+
)
768783
agent_executor = self.create_sql_agent(
769784
toolkit=toolkit,
770785
verbose=True,
@@ -858,19 +873,34 @@ def stream_response(
858873
new_fewshot_examples = None
859874
number_of_samples = 0
860875
self.database = SQLDatabase.get_sql_engine(database_connection)
861-
toolkit = SQLDatabaseToolkit(
862-
queuer=queue,
863-
db=self.database,
864-
context=[{}],
865-
few_shot_examples=new_fewshot_examples,
866-
instructions=instructions,
867-
is_multiple_schema=True if user_prompt.schemas else False,
868-
db_scan=db_scan,
869-
embedding=OpenAIEmbeddings(
870-
openai_api_key=database_connection.decrypt_api_key(),
871-
model=EMBEDDING_MODEL,
872-
),
873-
)
876+
#Set Embeddings class depending on azure / not azure
877+
if self.llm.openai_api_type == "azure":
878+
toolkit = SQLDatabaseToolkit(
879+
db=self.database,
880+
context=context,
881+
few_shot_examples=new_fewshot_examples,
882+
instructions=instructions,
883+
is_multiple_schema=True if user_prompt.schemas else False,
884+
db_scan=db_scan,
885+
embedding=AzureOpenAIEmbeddings(
886+
openai_api_key=database_connection.decrypt_api_key(),
887+
model=EMBEDDING_MODEL,
888+
),
889+
)
890+
else:
891+
toolkit = SQLDatabaseToolkit(
892+
queuer=queue,
893+
db=self.database,
894+
context=[{}],
895+
few_shot_examples=new_fewshot_examples,
896+
instructions=instructions,
897+
is_multiple_schema=True if user_prompt.schemas else False,
898+
db_scan=db_scan,
899+
embedding=OpenAIEmbeddings(
900+
openai_api_key=database_connection.decrypt_api_key(),
901+
model=EMBEDDING_MODEL,
902+
),
903+
)
874904
agent_executor = self.create_sql_agent(
875905
toolkit=toolkit,
876906
verbose=True,

Diff for: docker-compose.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ services:
2424
ports:
2525
- 27017:27017
2626
volumes:
27-
- ./initdb.d/:/docker-entrypoint-initdb.d/
27+
- ./initdb.d/init-mongo.sh:/docker-entrypoint-initdb.d/init-mongo.sh:ro
2828
- ./dbdata/mongo_data/data:/data/db/
2929
- ./dbdata/mongo_data/db_config:/data/configdb/
3030
environment:

Diff for: initdb.d/init-mongo.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
set -e
1+
# set -e
22

33
mongosh <<EOF
44
use $MONGO_INITDB_DATABASE

0 commit comments

Comments
 (0)