21
21
from langchain .chains .llm import LLMChain
22
22
from langchain .tools .base import BaseTool
23
23
from langchain_community .callbacks import get_openai_callback
24
- from langchain_openai import OpenAIEmbeddings
24
+ from langchain_openai import AzureOpenAIEmbeddings , OpenAIEmbeddings
25
25
from openai import OpenAI
26
26
from overrides import override
27
27
from pydantic import BaseModel , Field
@@ -587,14 +587,24 @@ def generate_response(
587
587
)
588
588
finetunings_repository = FinetuningsRepository (storage )
589
589
finetuning = finetunings_repository .find_by_id (self .finetuning_id )
590
- openai_fine_tuning = OpenAIFineTuning (storage , finetuning )
590
+ openai_fine_tuning = OpenAIFineTuning (self . system , storage , finetuning )
591
591
finetuning = openai_fine_tuning .retrieve_finetuning_job ()
592
592
if finetuning .status != FineTuningStatus .SUCCEEDED .value :
593
593
raise FinetuningNotAvailableError (
594
594
f"Finetuning({ self .finetuning_id } ) has the status { finetuning .status } ."
595
595
f"Finetuning should have the status { FineTuningStatus .SUCCEEDED .value } to generate SQL queries."
596
596
)
597
597
self .database = SQLDatabase .get_sql_engine (database_connection )
598
+ if self .system .settings ["azure_api_key" ] is not None :
599
+ embedding = AzureOpenAIEmbeddings (
600
+ openai_api_key = database_connection .decrypt_api_key (),
601
+ model = EMBEDDING_MODEL ,
602
+ )
603
+ else :
604
+ embedding = OpenAIEmbeddings (
605
+ openai_api_key = database_connection .decrypt_api_key (),
606
+ model = EMBEDDING_MODEL ,
607
+ )
598
608
toolkit = SQLDatabaseToolkit (
599
609
db = self .database ,
600
610
instructions = instructions ,
@@ -605,10 +615,7 @@ def generate_response(
605
615
use_finetuned_model_only = self .use_fintuned_model_only ,
606
616
model_name = finetuning .base_llm .model_name ,
607
617
openai_fine_tuning = openai_fine_tuning ,
608
- embedding = OpenAIEmbeddings ( #TODO AzureOpenAIEmbeddings when Azure
609
- openai_api_key = database_connection .decrypt_api_key (),
610
- model = EMBEDDING_MODEL ,
611
- ),
618
+ embedding = embedding ,
612
619
)
613
620
agent_executor = self .create_sql_agent (
614
621
toolkit = toolkit ,
@@ -693,14 +700,24 @@ def stream_response(
693
700
)
694
701
finetunings_repository = FinetuningsRepository (storage )
695
702
finetuning = finetunings_repository .find_by_id (self .finetuning_id )
696
- openai_fine_tuning = OpenAIFineTuning (storage , finetuning )
703
+ openai_fine_tuning = OpenAIFineTuning (self . system , storage , finetuning )
697
704
finetuning = openai_fine_tuning .retrieve_finetuning_job ()
698
705
if finetuning .status != FineTuningStatus .SUCCEEDED .value :
699
706
raise FinetuningNotAvailableError (
700
707
f"Finetuning({ self .finetuning_id } ) has the status { finetuning .status } ."
701
708
f"Finetuning should have the status { FineTuningStatus .SUCCEEDED .value } to generate SQL queries."
702
709
)
703
710
self .database = SQLDatabase .get_sql_engine (database_connection )
711
+ if self .system .settings ["azure_api_key" ] is not None :
712
+ embedding = AzureOpenAIEmbeddings (
713
+ openai_api_key = database_connection .decrypt_api_key (),
714
+ model = EMBEDDING_MODEL ,
715
+ )
716
+ else :
717
+ embedding = OpenAIEmbeddings (
718
+ openai_api_key = database_connection .decrypt_api_key (),
719
+ model = EMBEDDING_MODEL ,
720
+ )
704
721
toolkit = SQLDatabaseToolkit (
705
722
db = self .database ,
706
723
instructions = instructions ,
@@ -710,10 +727,7 @@ def stream_response(
710
727
use_finetuned_model_only = self .use_fintuned_model_only ,
711
728
model_name = finetuning .base_llm .model_name ,
712
729
openai_fine_tuning = openai_fine_tuning ,
713
- embedding = OpenAIEmbeddings ( #TODO AzureOpenAIEmbeddings when Azure
714
- openai_api_key = database_connection .decrypt_api_key (),
715
- model = EMBEDDING_MODEL ,
716
- ),
730
+ embedding = embedding ,
717
731
)
718
732
agent_executor = self .create_sql_agent (
719
733
toolkit = toolkit ,
0 commit comments