Skip to content

Commit 0b28714

Browse files
authored
Merge pull request #137 from cagostino/chris/installation_target_testing
added new installation targets and accommodated postgres connections with NPCs and in get_data_response.
2 parents 34d0a51 + 9532403 commit 0b28714

12 files changed

+711
-102
lines changed

README.md

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -399,11 +399,16 @@ ollama pull llama3.2
399399
ollama pull llava:7b
400400
ollama pull nomic-embed-text
401401
pip install npcsh
402-
```
403-
If you'd like to install the abilities to use STT and TTS, additionally install the following
404-
```
405-
pip install openai-whisper pyaudio gtts playsound
406-
```
402+
# if you want to install with the API libraries
403+
pip install npcsh[lite]
404+
# if you want the full local package set up (ollama, diffusers, transformers, cuda etc.)
405+
pip install npcsh[local]
406+
# if you want to use tts/stt
407+
pip install npcsh[whisper]
408+
409+
# if you want everything:
410+
pip install npcsh[all]
411+
407412

408413

409414

@@ -418,6 +423,16 @@ ollama pull llama3.2
418423
ollama pull llava:7b
419424
ollama pull nomic-embed-text
420425
pip install npcsh
426+
# if you want to install with the API libraries
427+
pip install npcsh[lite]
428+
# if you want the full local package set up (ollama, diffusers, transformers, cuda etc.)
429+
pip install npcsh[local]
430+
# if you want to use tts/stt
431+
pip install npcsh[whisper]
432+
433+
# if you want everything:
434+
pip install npcsh[all]
435+
421436
```
422437
### Windows Install
423438

@@ -430,6 +445,16 @@ ollama pull llama3.2
430445
ollama pull llava:7b
431446
ollama pull nomic-embed-text
432447
pip install npcsh
448+
# if you want to install with the API libraries
449+
pip install npcsh[lite]
450+
# if you want the full local package set up (ollama, diffusers, transformers, cuda etc.)
451+
pip install npcsh[local]
452+
# if you want to use tts/stt
453+
pip install npcsh[whisper]
454+
455+
# if you want everything:
456+
pip install npcsh[all]
457+
433458
```
434459
As of now, npcsh appears to work well with some of the core functionalities like /ots and /whisper.
435460

npcsh/helpers.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import subprocess
66
import platform
77
import yaml
8-
import nltk
8+
9+
try:
10+
import nltk
11+
except:
12+
print("Error importing nltk")
913
import numpy as np
1014

1115
import filecmp

npcsh/llm_funcs.py

Lines changed: 111 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from google.generativeai import types
1919
import google.generativeai as genai
20-
20+
from sqlalchemy import create_engine
2121

2222
from .npc_sysenv import (
2323
get_system_message,
@@ -1554,7 +1554,7 @@ def check_output_sufficient(
15541554

15551555
def process_data_output(
15561556
llm_response: Dict[str, Any],
1557-
db_conn: sqlite3.Connection,
1557+
db_conn,
15581558
request: str,
15591559
tables: str = None,
15601560
history: str = None,
@@ -1572,9 +1572,15 @@ def process_data_output(
15721572
if not query:
15731573
return {"response": "No query provided", "code": 400}
15741574

1575+
# Create SQLAlchemy engine based on connection type
1576+
if "psycopg2" in db_conn.__class__.__module__:
1577+
engine = create_engine("postgresql://caug:gobears@localhost/npc_test")
1578+
else:
1579+
engine = create_engine("sqlite:///test_sqlite.db")
1580+
15751581
if choice == 1: # Direct answer query
15761582
try:
1577-
df = pd.read_sql_query(query, db_conn)
1583+
df = pd.read_sql_query(query, engine)
15781584
result = check_output_sufficient(
15791585
request, df, query, model=model, provider=provider, npc=npc
15801586
)
@@ -1591,7 +1597,7 @@ def process_data_output(
15911597

15921598
elif choice == 2: # Exploratory query
15931599
try:
1594-
df = pd.read_sql_query(query, db_conn)
1600+
df = pd.read_sql_query(query, engine)
15951601
extra_context = f"""
15961602
Exploratory query results:
15971603
Query: {query}
@@ -1621,7 +1627,7 @@ def process_data_output(
16211627

16221628
def get_data_response(
16231629
request: str,
1624-
db_conn: sqlite3.Connection,
1630+
db_conn,
16251631
tables: str = None,
16261632
n_try_freq: int = 5,
16271633
extra_context: str = None,
@@ -1634,9 +1640,73 @@ def get_data_response(
16341640
"""
16351641
Generate a response to a data request, with retries for failed attempts.
16361642
"""
1643+
1644+
# Extract schema information based on connection type
1645+
schema_info = ""
1646+
if "psycopg2" in db_conn.__class__.__module__:
1647+
cursor = db_conn.cursor()
1648+
# Get all tables and their columns
1649+
cursor.execute(
1650+
"""
1651+
SELECT
1652+
t.table_name,
1653+
array_agg(c.column_name || ' ' || c.data_type) as columns,
1654+
array_agg(
1655+
CASE
1656+
WHEN tc.constraint_type = 'FOREIGN KEY'
1657+
THEN kcu.column_name || ' REFERENCES ' || ccu.table_name || '.' || ccu.column_name
1658+
ELSE NULL
1659+
END
1660+
) as foreign_keys
1661+
FROM information_schema.tables t
1662+
JOIN information_schema.columns c ON t.table_name = c.table_name
1663+
LEFT JOIN information_schema.table_constraints tc
1664+
ON t.table_name = tc.table_name
1665+
AND tc.constraint_type = 'FOREIGN KEY'
1666+
LEFT JOIN information_schema.key_column_usage kcu
1667+
ON tc.constraint_name = kcu.constraint_name
1668+
LEFT JOIN information_schema.constraint_column_usage ccu
1669+
ON tc.constraint_name = ccu.constraint_name
1670+
WHERE t.table_schema = 'public'
1671+
GROUP BY t.table_name;
1672+
"""
1673+
)
1674+
for table, columns, fks in cursor.fetchall():
1675+
schema_info += f"\nTable {table}:\n"
1676+
schema_info += "Columns:\n"
1677+
for col in columns:
1678+
schema_info += f" - {col}\n"
1679+
if any(fk for fk in fks if fk is not None):
1680+
schema_info += "Foreign Keys:\n"
1681+
for fk in fks:
1682+
if fk:
1683+
schema_info += f" - {fk}\n"
1684+
1685+
elif "sqlite3" in db_conn.__class__.__module__:
1686+
cursor = db_conn.cursor()
1687+
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
1688+
tables = cursor.fetchall()
1689+
for (table_name,) in tables:
1690+
schema_info += f"\nTable {table_name}:\n"
1691+
cursor.execute(f"PRAGMA table_info({table_name});")
1692+
columns = cursor.fetchall()
1693+
schema_info += "Columns:\n"
1694+
for col in columns:
1695+
schema_info += f" - {col[1]} {col[2]}\n"
1696+
1697+
cursor.execute(f"PRAGMA foreign_key_list({table_name});")
1698+
foreign_keys = cursor.fetchall()
1699+
if foreign_keys:
1700+
schema_info += "Foreign Keys:\n"
1701+
for fk in foreign_keys:
1702+
schema_info += f" - {fk[3]} REFERENCES {fk[2]}({fk[4]})\n"
1703+
16371704
prompt = f"""
16381705
User request: {request}
1639-
Available tables: {tables or 'Not specified'}
1706+
1707+
Database Schema:
1708+
{schema_info}
1709+
16401710
{extra_context or ''}
16411711
{f'Query history: {history}' if history else ''}
16421712
@@ -1655,49 +1725,47 @@ def get_data_response(
16551725

16561726
failures = []
16571727
for attempt in range(max_retries):
1658-
try:
1659-
llm_response = get_llm_response(
1660-
prompt, npc=npc, format="json", model=model, provider=provider
1661-
)
1728+
# try:
1729+
llm_response = get_llm_response(
1730+
prompt, npc=npc, format="json", model=model, provider=provider
1731+
)
16621732

1663-
# Clean response if it's a string
1664-
response_data = llm_response.get("response", {})
1665-
if isinstance(response_data, str):
1666-
response_data = (
1667-
response_data.replace("```json", "").replace("```", "").strip()
1668-
)
1669-
try:
1670-
response_data = json.loads(response_data)
1671-
except json.JSONDecodeError:
1672-
failures.append("Invalid JSON response")
1673-
continue
1674-
1675-
result = process_data_output(
1676-
response_data,
1677-
db_conn,
1678-
request,
1679-
tables=tables,
1680-
history=failures,
1681-
npc=npc,
1682-
model=model,
1683-
provider=provider,
1733+
# Clean response if it's a string
1734+
response_data = llm_response.get("response", {})
1735+
if isinstance(response_data, str):
1736+
response_data = (
1737+
response_data.replace("```json", "").replace("```", "").strip()
16841738
)
1739+
try:
1740+
response_data = json.loads(response_data)
1741+
except json.JSONDecodeError:
1742+
failures.append("Invalid JSON response")
1743+
continue
1744+
1745+
result = process_data_output(
1746+
response_data,
1747+
db_conn,
1748+
request,
1749+
tables=tables,
1750+
history=failures,
1751+
npc=npc,
1752+
model=model,
1753+
provider=provider,
1754+
)
16851755

1686-
if result["code"] == 200:
1687-
return result
1688-
1689-
failures.append(result["response"])
1756+
if result["code"] == 200:
1757+
return result
16901758

1691-
if attempt == max_retries - 1:
1692-
return {
1693-
"response": f"Failed after {max_retries} attempts. Errors: {'; '.join(failures)}",
1694-
"code": 400,
1695-
}
1759+
failures.append(result["response"])
16961760

1697-
except Exception as e:
1698-
failures.append(str(e))
1761+
if attempt == max_retries - 1:
1762+
return {
1763+
"response": f"Failed after {max_retries} attempts. Errors: {'; '.join(failures)}",
1764+
"code": 400,
1765+
}
16991766

1700-
return {"response": "Max retries exceeded", "code": 400}
1767+
# except Exception as e:
1768+
# failures.append(str(e))
17011769

17021770

17031771
def enter_reasoning_human_in_the_loop(

npcsh/npc_compiler.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -788,11 +788,29 @@ def __init__(
788788
self.model = model
789789
self.db_conn = db_conn
790790
if self.db_conn is not None:
791-
self.tables = self.db_conn.execute(
792-
"SELECT name, sql FROM sqlite_master WHERE type='table';"
793-
).fetchall()
791+
# Determine database type
792+
if "psycopg2" in self.db_conn.__class__.__module__:
793+
# PostgreSQL connection
794+
cursor = self.db_conn.cursor()
795+
cursor.execute(
796+
"""
797+
SELECT table_name, obj_description((quote_ident(table_name))::regclass, 'pg_class')
798+
FROM information_schema.tables
799+
WHERE table_schema='public';
800+
"""
801+
)
802+
self.tables = cursor.fetchall()
803+
self.db_type = "postgres"
804+
elif "sqlite3" in self.db_conn.__class__.__module__:
805+
# SQLite connection
806+
self.tables = self.db_conn.execute(
807+
"SELECT name, sql FROM sqlite_master WHERE type='table';"
808+
).fetchall()
809+
self.db_type = "sqlite"
794810
else:
795811
self.tables = None
812+
self.db_type = None
813+
796814
self.provider = provider
797815
self.api_url = api_url
798816
self.all_tools = all_tools or []
@@ -839,6 +857,45 @@ def __init__(
839857
else:
840858
self.parsed_npcs = []
841859

860+
def execute_query(self, query, params=None):
861+
"""Execute a query based on database type"""
862+
if self.db_type == "postgres":
863+
cursor = self.db_conn.cursor()
864+
cursor.execute(query, params or ())
865+
return cursor.fetchall()
866+
else: # sqlite
867+
cursor = self.db_conn.execute(query, params or ())
868+
return cursor.fetchall()
869+
870+
def _determine_db_type(self):
871+
"""Determine if the connection is PostgreSQL or SQLite"""
872+
# Check the connection object's class name
873+
conn_type = self.db_conn.__class__.__module__.lower()
874+
875+
if "psycopg" in conn_type:
876+
return "postgres"
877+
elif "sqlite" in conn_type:
878+
return "sqlite"
879+
else:
880+
raise ValueError(f"Unsupported database type: {conn_type}")
881+
882+
def _get_tables(self):
883+
"""Get table information based on database type"""
884+
if self.db_type == "postgres":
885+
cursor = self.db_conn.cursor()
886+
cursor.execute(
887+
"""
888+
SELECT table_name, obj_description((quote_ident(table_name))::regclass, 'pg_class') as description
889+
FROM information_schema.tables
890+
WHERE table_schema='public';
891+
"""
892+
)
893+
return cursor.fetchall()
894+
else: # sqlite
895+
return self.db_conn.execute(
896+
"SELECT name, sql FROM sqlite_master WHERE type='table';"
897+
).fetchall()
898+
842899
def get_memory(self):
843900
return
844901

npcsh/response.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,9 @@ def get_openai_response(
230230

231231
# try:
232232
if api_key is None:
233-
api_key = os.environ["OPENAI_API_KEY"]
233+
api_key = os.environ.get("OPENAI_API_KEY", "")
234+
if len(api_key) == 0:
235+
raise ValueError("API key not found.")
234236
client = OpenAI(api_key=api_key)
235237
# print(npc)
236238

npcsh/search.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
from bs4 import BeautifulSoup
77
from duckduckgo_search import DDGS
8-
from googlesearch import search
8+
9+
try:
10+
from googlesearch import search
11+
except:
12+
pass
913
from typing import List, Dict, Any, Optional, Union
1014
import numpy as np
1115
import json

0 commit comments

Comments
 (0)