17
17
18
18
from google .generativeai import types
19
19
import google .generativeai as genai
20
-
20
+ from sqlalchemy import create_engine
21
21
22
22
from .npc_sysenv import (
23
23
get_system_message ,
@@ -1554,7 +1554,7 @@ def check_output_sufficient(
1554
1554
1555
1555
def process_data_output (
1556
1556
llm_response : Dict [str , Any ],
1557
- db_conn : sqlite3 . Connection ,
1557
+ db_conn ,
1558
1558
request : str ,
1559
1559
tables : str = None ,
1560
1560
history : str = None ,
@@ -1572,9 +1572,15 @@ def process_data_output(
1572
1572
if not query :
1573
1573
return {"response" : "No query provided" , "code" : 400 }
1574
1574
1575
+ # Create SQLAlchemy engine based on connection type
1576
+ if "psycopg2" in db_conn .__class__ .__module__ :
1577
+ engine = create_engine ("postgresql://caug:gobears@localhost/npc_test" )
1578
+ else :
1579
+ engine = create_engine ("sqlite:///test_sqlite.db" )
1580
+
1575
1581
if choice == 1 : # Direct answer query
1576
1582
try :
1577
- df = pd .read_sql_query (query , db_conn )
1583
+ df = pd .read_sql_query (query , engine )
1578
1584
result = check_output_sufficient (
1579
1585
request , df , query , model = model , provider = provider , npc = npc
1580
1586
)
@@ -1591,7 +1597,7 @@ def process_data_output(
1591
1597
1592
1598
elif choice == 2 : # Exploratory query
1593
1599
try :
1594
- df = pd .read_sql_query (query , db_conn )
1600
+ df = pd .read_sql_query (query , engine )
1595
1601
extra_context = f"""
1596
1602
Exploratory query results:
1597
1603
Query: { query }
@@ -1621,7 +1627,7 @@ def process_data_output(
1621
1627
1622
1628
def get_data_response (
1623
1629
request : str ,
1624
- db_conn : sqlite3 . Connection ,
1630
+ db_conn ,
1625
1631
tables : str = None ,
1626
1632
n_try_freq : int = 5 ,
1627
1633
extra_context : str = None ,
@@ -1634,9 +1640,73 @@ def get_data_response(
1634
1640
"""
1635
1641
Generate a response to a data request, with retries for failed attempts.
1636
1642
"""
1643
+
1644
+ # Extract schema information based on connection type
1645
+ schema_info = ""
1646
+ if "psycopg2" in db_conn .__class__ .__module__ :
1647
+ cursor = db_conn .cursor ()
1648
+ # Get all tables and their columns
1649
+ cursor .execute (
1650
+ """
1651
+ SELECT
1652
+ t.table_name,
1653
+ array_agg(c.column_name || ' ' || c.data_type) as columns,
1654
+ array_agg(
1655
+ CASE
1656
+ WHEN tc.constraint_type = 'FOREIGN KEY'
1657
+ THEN kcu.column_name || ' REFERENCES ' || ccu.table_name || '.' || ccu.column_name
1658
+ ELSE NULL
1659
+ END
1660
+ ) as foreign_keys
1661
+ FROM information_schema.tables t
1662
+ JOIN information_schema.columns c ON t.table_name = c.table_name
1663
+ LEFT JOIN information_schema.table_constraints tc
1664
+ ON t.table_name = tc.table_name
1665
+ AND tc.constraint_type = 'FOREIGN KEY'
1666
+ LEFT JOIN information_schema.key_column_usage kcu
1667
+ ON tc.constraint_name = kcu.constraint_name
1668
+ LEFT JOIN information_schema.constraint_column_usage ccu
1669
+ ON tc.constraint_name = ccu.constraint_name
1670
+ WHERE t.table_schema = 'public'
1671
+ GROUP BY t.table_name;
1672
+ """
1673
+ )
1674
+ for table , columns , fks in cursor .fetchall ():
1675
+ schema_info += f"\n Table { table } :\n "
1676
+ schema_info += "Columns:\n "
1677
+ for col in columns :
1678
+ schema_info += f" - { col } \n "
1679
+ if any (fk for fk in fks if fk is not None ):
1680
+ schema_info += "Foreign Keys:\n "
1681
+ for fk in fks :
1682
+ if fk :
1683
+ schema_info += f" - { fk } \n "
1684
+
1685
+ elif "sqlite3" in db_conn .__class__ .__module__ :
1686
+ cursor = db_conn .cursor ()
1687
+ cursor .execute ("SELECT name FROM sqlite_master WHERE type='table';" )
1688
+ tables = cursor .fetchall ()
1689
+ for (table_name ,) in tables :
1690
+ schema_info += f"\n Table { table_name } :\n "
1691
+ cursor .execute (f"PRAGMA table_info({ table_name } );" )
1692
+ columns = cursor .fetchall ()
1693
+ schema_info += "Columns:\n "
1694
+ for col in columns :
1695
+ schema_info += f" - { col [1 ]} { col [2 ]} \n "
1696
+
1697
+ cursor .execute (f"PRAGMA foreign_key_list({ table_name } );" )
1698
+ foreign_keys = cursor .fetchall ()
1699
+ if foreign_keys :
1700
+ schema_info += "Foreign Keys:\n "
1701
+ for fk in foreign_keys :
1702
+ schema_info += f" - { fk [3 ]} REFERENCES { fk [2 ]} ({ fk [4 ]} )\n "
1703
+
1637
1704
prompt = f"""
1638
1705
User request: { request }
1639
- Available tables: { tables or 'Not specified' }
1706
+
1707
+ Database Schema:
1708
+ { schema_info }
1709
+
1640
1710
{ extra_context or '' }
1641
1711
{ f'Query history: { history } ' if history else '' }
1642
1712
@@ -1655,49 +1725,47 @@ def get_data_response(
1655
1725
1656
1726
failures = []
1657
1727
for attempt in range (max_retries ):
1658
- try :
1659
- llm_response = get_llm_response (
1660
- prompt , npc = npc , format = "json" , model = model , provider = provider
1661
- )
1728
+ # try:
1729
+ llm_response = get_llm_response (
1730
+ prompt , npc = npc , format = "json" , model = model , provider = provider
1731
+ )
1662
1732
1663
- # Clean response if it's a string
1664
- response_data = llm_response .get ("response" , {})
1665
- if isinstance (response_data , str ):
1666
- response_data = (
1667
- response_data .replace ("```json" , "" ).replace ("```" , "" ).strip ()
1668
- )
1669
- try :
1670
- response_data = json .loads (response_data )
1671
- except json .JSONDecodeError :
1672
- failures .append ("Invalid JSON response" )
1673
- continue
1674
-
1675
- result = process_data_output (
1676
- response_data ,
1677
- db_conn ,
1678
- request ,
1679
- tables = tables ,
1680
- history = failures ,
1681
- npc = npc ,
1682
- model = model ,
1683
- provider = provider ,
1733
+ # Clean response if it's a string
1734
+ response_data = llm_response .get ("response" , {})
1735
+ if isinstance (response_data , str ):
1736
+ response_data = (
1737
+ response_data .replace ("```json" , "" ).replace ("```" , "" ).strip ()
1684
1738
)
1739
+ try :
1740
+ response_data = json .loads (response_data )
1741
+ except json .JSONDecodeError :
1742
+ failures .append ("Invalid JSON response" )
1743
+ continue
1744
+
1745
+ result = process_data_output (
1746
+ response_data ,
1747
+ db_conn ,
1748
+ request ,
1749
+ tables = tables ,
1750
+ history = failures ,
1751
+ npc = npc ,
1752
+ model = model ,
1753
+ provider = provider ,
1754
+ )
1685
1755
1686
- if result ["code" ] == 200 :
1687
- return result
1688
-
1689
- failures .append (result ["response" ])
1756
+ if result ["code" ] == 200 :
1757
+ return result
1690
1758
1691
- if attempt == max_retries - 1 :
1692
- return {
1693
- "response" : f"Failed after { max_retries } attempts. Errors: { '; ' .join (failures )} " ,
1694
- "code" : 400 ,
1695
- }
1759
+ failures .append (result ["response" ])
1696
1760
1697
- except Exception as e :
1698
- failures .append (str (e ))
1761
+ if attempt == max_retries - 1 :
1762
+ return {
1763
+ "response" : f"Failed after { max_retries } attempts. Errors: { '; ' .join (failures )} " ,
1764
+ "code" : 400 ,
1765
+ }
1699
1766
1700
- return {"response" : "Max retries exceeded" , "code" : 400 }
1767
+ # except Exception as e:
1768
+ # failures.append(str(e))
1701
1769
1702
1770
1703
1771
def enter_reasoning_human_in_the_loop (
0 commit comments