@@ -190,12 +190,20 @@ def similart_tables_based_on_few_shot_examples(self, df: pd.DataFrame) -> List[s
190
190
tables = Parser (example ["sql" ]).tables
191
191
except Exception as e :
192
192
logger .error (f"Error parsing SQL: { str (e )} " )
193
- most_similar_tables .update (tables )
194
- df .drop (df [df .table_name .isin (most_similar_tables )].index , inplace = True )
193
+ for table in tables :
194
+ found_tables = df [df .table_name == table ]
195
+ for _ , row in found_tables .iterrows ():
196
+ most_similar_tables .add ((row ["schema_name" ], row ["table_name" ]))
197
+ df .drop (
198
+ df [
199
+ df .table_name .isin ([table [1 ] for table in most_similar_tables ])
200
+ ].index ,
201
+ inplace = True ,
202
+ )
195
203
return most_similar_tables
196
204
197
205
@catch_exceptions ()
198
- def _run (
206
+ def _run ( # noqa: PLR0912
199
207
self ,
200
208
user_question : str ,
201
209
run_manager : CallbackManagerForToolRun | None = None , # noqa: ARG002
@@ -214,9 +222,12 @@ def _run(
214
222
table_rep = f"Table { table .table_name } contain columns: [{ col_rep } ], this tables has: { table .description } "
215
223
else :
216
224
table_rep = f"Table { table .table_name } contain columns: [{ col_rep } ]"
217
- table_representations .append ([table .table_name , table_rep ])
225
+ table_representations .append (
226
+ [table .schema_name , table .table_name , table_rep ]
227
+ )
218
228
df = pd .DataFrame (
219
- table_representations , columns = ["table_name" , "table_representation" ]
229
+ table_representations ,
230
+ columns = ["schema_name" , "table_name" , "table_representation" ],
220
231
)
221
232
df ["table_embedding" ] = self .get_docs_embedding (df .table_representation )
222
233
df ["similarities" ] = df .table_embedding .apply (
@@ -227,12 +238,20 @@ def _run(
227
238
most_similar_tables = self .similart_tables_based_on_few_shot_examples (df )
228
239
table_relevance = ""
229
240
for _ , row in df .iterrows ():
230
- table_relevance += f'Table: `{ row ["table_name" ]} `, relevance score: { row ["similarities" ]} \n '
241
+ if row ["schema_name" ] is not None :
242
+ table_name = row ["schema_name" ] + "." + row ["table_name" ]
243
+ else :
244
+ table_name = row ["table_name" ]
245
+ table_relevance += (
246
+ f'Table: `{ table_name } `, relevance score: { row ["similarities" ]} \n '
247
+ )
231
248
if len (most_similar_tables ) > 0 :
232
249
for table in most_similar_tables :
233
- table_relevance += (
234
- f"Table: `{ table } `, relevance score: { max (df ['similarities' ])} \n "
235
- )
250
+ if table [0 ] is not None :
251
+ table_name = table [0 ] + "." + table [1 ]
252
+ else :
253
+ table_name = table [1 ]
254
+ table_relevance += f"Table: `{ table_name } `, relevance score: { max (df ['similarities' ])} \n "
236
255
return table_relevance
237
256
238
257
async def _arun (
@@ -358,27 +377,32 @@ class SchemaSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
358
377
db_scan : List [TableDescription ]
359
378
360
379
@catch_exceptions ()
361
- def _run (
380
+ def _run ( # noqa: C901
362
381
self ,
363
382
table_names : str ,
364
383
run_manager : CallbackManagerForToolRun | None = None , # noqa: ARG002
365
384
) -> str :
366
385
"""Get the schema for tables in a comma-separated list."""
367
386
table_names_list = table_names .split (", " )
368
- table_names_list = [
369
- replace_unprocessable_characters (table_name )
370
- for table_name in table_names_list
371
- ]
387
+ processed_table_names = []
388
+ for table in table_names_list :
389
+ formatted_table = replace_unprocessable_characters (table )
390
+ if "." in formatted_table :
391
+ processed_table_names .append (formatted_table .split ("." )[1 ])
392
+ else :
393
+ processed_table_names .append (formatted_table )
372
394
tables_schema = ""
373
395
for table in self .db_scan :
374
- if table .table_name in table_names_list :
396
+ if table .table_name in processed_table_names :
375
397
tables_schema += "```sql\n "
376
398
tables_schema += table .table_schema + "\n "
377
399
descriptions = []
378
400
if table .description is not None :
379
- descriptions .append (
380
- f"Table `{ table .table_name } `: { table .description } \n "
381
- )
401
+ if table .schema_name :
402
+ table_name = f"{ table .schema_name } .{ table .table_name } "
403
+ else :
404
+ table_name = table .table_name
405
+ descriptions .append (f"Table `{ table_name } `: { table .description } \n " )
382
406
for column in table .columns :
383
407
if column .description is not None :
384
408
descriptions .append (
@@ -555,6 +579,9 @@ def generate_response(
555
579
)
556
580
if not db_scan :
557
581
raise ValueError ("No scanned tables found for database" )
582
+ db_scan = SQLGenerator .filter_tables_by_schema (
583
+ db_scan = db_scan , prompt = user_prompt
584
+ )
558
585
few_shot_examples , instructions = context_store .retrieve_context_for_question (
559
586
user_prompt , number_of_samples = 5
560
587
)
@@ -658,6 +685,9 @@ def stream_response(
658
685
)
659
686
if not db_scan :
660
687
raise ValueError ("No scanned tables found for database" )
688
+ db_scan = SQLGenerator .filter_tables_by_schema (
689
+ db_scan = db_scan , prompt = user_prompt
690
+ )
661
691
_ , instructions = context_store .retrieve_context_for_question (
662
692
user_prompt , number_of_samples = 1
663
693
)
0 commit comments