Skip to content

Commit e35a7b8

Browse files
authored Sep 24, 2024
Add Skip Straight to SQL Query Execution Option and Include Results (#34)
* Add option to pre-run the first option in the query cache.
* Speed up existing approach for vector and cache
* Add experimentation data
* Improve write-up
1 parent 8073265 commit e35a7b8

21 files changed

+1613
-2048
lines changed
 

‎.devcontainer/devcontainer.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
"ms-python.python",
1010
"ms-python.vscode-pylance",
1111
"ms-toolsai.jupyter",
12-
"ms-azuretools.vscode-docker"
12+
"ms-azuretools.vscode-docker",
13+
"ms-azuretools.vscode-azurefunctions"
1314
],
1415
"settings": {
1516
"editor.formatOnSave": true

‎deploy_ai_search/deploy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from rag_documents import RagDocumentsAISearch
55
from text_2_sql import Text2SqlAISearch
66
from text_2_sql_query_cache import Text2SqlQueryCacheAISearch
7+
import logging
8+
9+
logging.basicConfig(level=logging.INFO)
710

811

912
def deploy_config(arguments: argparse.Namespace):

‎deploy_ai_search/text_2_sql.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def get_index_fields(self) -> list[SearchableField]:
9393
type=SearchFieldDataType.String,
9494
collection=True,
9595
hidden=True,
96-
), # This is needed to enable semantic searching against the column names as complex field types are not used.
96+
# This is needed to enable semantic searching against the column names as complex field types are not used.
97+
),
9798
SimpleField(
9899
name="DateLastModified",
99100
type=SearchFieldDataType.DateTimeOffset,
@@ -213,7 +214,8 @@ def get_indexer(self) -> SearchIndexer:
213214
target_field_name="ColumnNames",
214215
),
215216
FieldMapping(
216-
name="DateLastModified", source="/document/DateLastModified"
217+
source_field_name="/document/DateLastModified",
218+
target_field_name="DateLastModified",
217219
),
218220
],
219221
parameters=indexer_parameters,

‎deploy_ai_search/text_2_sql_query_cache.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,6 @@ def get_index_fields(self) -> list[SearchableField]:
5555
SearchableField(
5656
name="Query", type=SearchFieldDataType.String, filterable=True
5757
),
58-
SearchField(
59-
name="QueryEmbedding",
60-
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
61-
vector_search_dimensions=self.environment.open_ai_embedding_dimensions,
62-
vector_search_profile_name=self.vector_search_profile_name,
63-
),
6458
ComplexField(
6559
name="Schemas",
6660
collection=True,

‎text_2_sql/.env

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ OpenAI__ApiKey=<openAIKey if using non managed identity>
55
OpenAI__ApiVersion=<openAIApiVersion>
66
Text2Sql__DatabaseEngine=<databaseEngine>
77
Text2Sql__UseQueryCache=<whether to use the query cache first or not>
8+
Text2Sql__PreRunQueryCache=<whether to pre-run the top result from the query cache or not>
89
Text2Sql__DatabaseName=<databaseName>
910
Text2Sql__DatabaseConnectionString=<databaseConnectionString>
1011
AIService__AzureSearchOptions__Endpoint=<searchServiceEndpoint>

‎text_2_sql/README.md

Lines changed: 73 additions & 18 deletions
Large diffs are not rendered by default.
Loading
6.57 KB
Loading
73.8 KB
Loading

‎text_2_sql/plugins/ai_search_plugin/ai_search_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
import json
77
import logging
8-
from ai_search import run_ai_search_query
8+
from utils.ai_search import run_ai_search_query
99

1010

1111
class AISearchPlugin:

‎text_2_sql/plugins/prompt_based_sql_plugin/prompt_based_sql_plugin.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from semantic_kernel.functions import kernel_function
4-
import aioodbc
54
from typing import Annotated
6-
import os
75
import json
86
import logging
7+
import os
8+
import aioodbc
99

1010

1111
class PromptBasedSQLPlugin:
@@ -32,7 +32,8 @@ def load_entities(self):
3232
for entity_object in entities:
3333
entity = entity_object["Entity"]
3434
entity_object["SelectFromEntity"] = f"{self.database}.{entity}"
35-
self.entities[entity.lower()] = entity_object
35+
entity_name = entity_object["EntityName"].lower()
36+
self.entities[entity_name] = entity_object
3637

3738
def system_prompt(self, engine_specific_rules: str | None = None) -> str:
3839
"""Get the schemas for the database entities and provide a system prompt for the user.
@@ -54,9 +55,14 @@ def system_prompt(self, engine_specific_rules: str | None = None) -> str:
5455
entity_descriptions = "\n\n ".join(entity_descriptions)
5556

5657
if engine_specific_rules:
57-
engine_specific_rules = f"\n The following {self.target_engine} Syntax rules must be adhered to.\n {engine_specific_rules}"
58+
engine_specific_rules = f"""\n The following {
59+
self.target_engine} Syntax rules must be adhered to.\n {engine_specific_rules}"""
60+
61+
system_prompt = f"""Use the names and descriptions of {self.target_engine} entities provided in ENTITIES LIST to decide which entities to query if you need to retrieve information from the database. Use the 'GetEntitySchema()' function to get more details of the schema of the view you want to query.
62+
63+
Always then use the 'RunSQLQuery()' function to run the SQL query against the database. Never just return the SQL query as the answer.
5864
59-
system_prompt = f"""Use the names and descriptions of {self.target_engine} entities provided in ENTITIES LIST to decide which entities to query if you need to retrieve information from the database. Use the 'GetEntitySchema()' function to get more details of the schema of the view you want to query. Use the 'RunSQLQuery()' function to run the SQL query against the database.
65+
Do not give the steps to the user in the response. Make sure to execute the SQL query and return the results in the response.
6066
6167
You must always examine the provided {self.target_engine} entity descriptions to determine if they can answer the question.
6268
@@ -111,7 +117,7 @@ async def get_entity_schema(
111117
return json.dumps({entity_name: self.entities[entity_name.lower()]})
112118

113119
@kernel_function(
114-
description="Runs an SQL query against the SQL Database to extract information.",
120+
description="Runs an SQL query against the SQL Database to extract information. This function must always be used during the answer generation process. Do not just return the SQL query as the answer.",
115121
name="RunSQLQuery",
116122
)
117123
async def run_sql_query(

0 commit comments

Comments
 (0)
Failed to load comments.