Skip to content

Commit dc4e743

Browse files
DH-5776/fixing the bug with Azure OpenAI (#481)
* DH-5776/fixing the bug with Azure OpenAI * DH-5776/ignore linter * DH-5776/ignore linter * DH-5776/fixing black
1 parent 2537ca5 commit dc4e743

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

Diff for: dataherald/services/sql_generations.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ def update_the_initial_sql_generation(
6363
initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps
6464
return self.sql_generation_repository.update(initial_sql_generation)
6565

66-
def create(
66+
def create( # noqa: PLR0912
6767
self, prompt_id: str, sql_generation_request: SQLGenerationRequest
68-
) -> SQLGeneration:
68+
) -> SQLGeneration: # noqa: PLR0912
6969
initial_sql_generation = SQLGeneration(
7070
prompt_id=prompt_id,
7171
created_at=datetime.now(),

Diff for: dataherald/sql_generator/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,15 @@ def generate_response(
179179
"""Generates a response to a user question."""
180180
pass
181181

182-
def stream_agent_steps( # noqa: C901
182+
def stream_agent_steps( # noqa: PLR0912, C901
183183
self,
184184
question: str,
185185
agent_executor: AgentExecutor,
186186
response: SQLGeneration,
187187
sql_generation_repository: SQLGenerationRepository,
188188
queue: Queue,
189189
metadata: dict = None,
190-
):
190+
): # noqa: PLR0912
191191
try:
192192
with get_openai_callback() as cb:
193193
for chunk in agent_executor.stream(

Diff for: dataherald/sql_generator/dataherald_finetuning_agent.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,7 @@ def generate_response(
595595
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
596596
)
597597
self.database = SQLDatabase.get_sql_engine(database_connection)
598-
if self.llm.openai_api_type == "azure":
598+
if self.system.settings["azure_api_key"] is not None:
599599
embedding = AzureOpenAIEmbeddings(
600600
openai_api_key=database_connection.decrypt_api_key(),
601601
model=EMBEDDING_MODEL,
@@ -708,7 +708,7 @@ def stream_response(
708708
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
709709
)
710710
self.database = SQLDatabase.get_sql_engine(database_connection)
711-
if self.llm.openai_api_type == "azure":
711+
if self.system.settings["azure_api_key"] is not None:
712712
embedding = AzureOpenAIEmbeddings(
713713
openai_api_key=database_connection.decrypt_api_key(),
714714
model=EMBEDDING_MODEL,

Diff for: dataherald/sql_generator/dataherald_sqlagent.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -710,13 +710,13 @@ def create_sql_agent(
710710
)
711711

712712
@override
713-
def generate_response(
713+
def generate_response( # noqa: PLR0912
714714
self,
715715
user_prompt: Prompt,
716716
database_connection: DatabaseConnection,
717717
context: List[dict] = None,
718718
metadata: dict = None,
719-
) -> SQLGeneration:
719+
) -> SQLGeneration: # noqa: PLR0912
720720
context_store = self.system.instance(ContextStore)
721721
storage = self.system.instance(DB)
722722
response = SQLGeneration(
@@ -754,7 +754,7 @@ def generate_response(
754754
logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
755755
self.database = SQLDatabase.get_sql_engine(database_connection)
756756
# Set Embeddings class depending on azure / not azure
757-
if self.llm.openai_api_type == "azure":
757+
if self.system.settings["azure_api_key"] is not None:
758758
toolkit = SQLDatabaseToolkit(
759759
db=self.database,
760760
context=context,
@@ -874,7 +874,7 @@ def stream_response(
874874
number_of_samples = 0
875875
self.database = SQLDatabase.get_sql_engine(database_connection)
876876
# Set Embeddings class depending on azure / not azure
877-
if self.llm.openai_api_type == "azure":
877+
if self.system.settings["azure_api_key"] is not None:
878878
embedding = AzureOpenAIEmbeddings(
879879
openai_api_key=database_connection.decrypt_api_key(),
880880
model=EMBEDDING_MODEL,

0 commit comments

Comments
 (0)