Skip to content

Commit 090d533

Browse files
authored
Merge pull request #129 from cagostino/chris/npc_upgrades
Chris/npc upgrades
2 parents dbf14c7 + db7bd35 commit 090d533

16 files changed

+1061
-413
lines changed

README.md

Lines changed: 242 additions & 73 deletions
Large diffs are not rendered by default.

npcsh/cli.py

Lines changed: 410 additions & 56 deletions
Large diffs are not rendered by default.

npcsh/conversation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def get_ollama_conversation(
2121
npc: Any = None,
2222
tools: list = None,
2323
images=None,
24+
**kwargs,
2425
) -> List[Dict[str, str]]:
2526
"""
2627
Function Description:

npcsh/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def ensure_npcshrc_exists() -> str:
8787
with open(npcshrc_path, "w") as npcshrc:
8888
npcshrc.write("# NPCSH Configuration File\n")
8989
npcshrc.write("export NPCSH_INITIALIZED=0\n")
90+
npcshrc.write("export NPCSH_DEFAULT_MODE='chat'\n")
9091
npcshrc.write("export NPCSH_CHAT_PROVIDER='ollama'\n")
9192
npcshrc.write("export NPCSH_CHAT_MODEL='llama3.2'\n")
9293
npcshrc.write("export NPCSH_REASONING_PROVIDER='ollama'\n")
@@ -99,6 +100,7 @@ def ensure_npcshrc_exists() -> str:
99100
npcshrc.write(
100101
"export NPCSH_IMAGE_GEN_MODEL='runwayml/stable-diffusion-v1-5'\n"
101102
)
103+
102104
npcshrc.write("export NPCSH_IMAGE_GEN_PROVIDER='diffusers'\n")
103105

104106
npcshrc.write("export NPCSH_API_URL=''\n")

npcsh/image.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Dict, Any
77
from PIL import ImageGrab # Import ImageGrab from Pillow
88

9+
from .npc_sysenv import NPCSH_VISION_MODEL, NPCSH_VISION_PROVIDER, NPCSH_API_URL
910
from .llm_funcs import get_llm_response, get_stream
1011
import os
1112

@@ -231,20 +232,22 @@ def analyze_image_base(
231232

232233

233234
def analyze_image(
234-
command_history: Any,
235235
user_prompt: str,
236236
file_path: str,
237237
filename: str,
238238
npc: Any = None,
239239
stream: bool = False,
240240
messages: list = None,
241-
**model_kwargs,
241+
model: str = NPCSH_VISION_MODEL,
242+
provider: str = NPCSH_VISION_PROVIDER,
243+
api_key: str = None,
244+
api_url: str = NPCSH_API_URL,
242245
) -> Dict[str, str]:
243246
"""
244247
Function Description:
245248
This function captures a screenshot, analyzes it using the LLM model, and returns the response.
246249
Args:
247-
command_history: The command history object to add the command to.
250+
248251
user_prompt: The user prompt to provide to the LLM model.
249252
file_path: The path to the image file.
250253
filename: The name of the image file.
@@ -271,17 +274,17 @@ def analyze_image(
271274

272275
else:
273276
response = get_llm_response(
274-
user_prompt, images=[image_info], npc=npc, **model_kwargs
277+
user_prompt,
278+
images=[image_info],
279+
npc=npc,
280+
model=model,
281+
provider=provider,
282+
api_url=api_url,
283+
api_key=api_key,
275284
)
276285

277286
print(response)
278287
# Add to command history *inside* the try block
279-
command_history.add_command(
280-
f"screenshot with prompt: {user_prompt}",
281-
["screenshot", npc.name if npc else ""],
282-
response,
283-
os.getcwd(),
284-
)
285288
return response
286289

287290
except Exception as e:

npcsh/llm_funcs.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
EMBEDDINGS_DB_PATH,
4747
NPCSH_EMBEDDING_MODEL,
4848
NPCSH_EMBEDDING_PROVIDER,
49+
NPCSH_DEFAULT_MODE,
4950
NPCSH_REASONING_MODEL,
5051
NPCSH_REASONING_PROVIDER,
5152
NPCSH_IMAGE_GEN_MODEL,
@@ -54,6 +55,8 @@
5455
NPCSH_VISION_MODEL,
5556
NPCSH_VISION_PROVIDER,
5657
chroma_client,
58+
available_reasoning_models,
59+
available_chat_models,
5760
)
5861

5962
from .stream import (
@@ -444,7 +447,6 @@ def get_conversation(
444447

445448
def execute_llm_question(
446449
command: str,
447-
command_history: Any,
448450
model: str = NPCSH_CHAT_MODEL,
449451
provider: str = NPCSH_CHAT_PROVIDER,
450452
api_url: str = NPCSH_API_URL,
@@ -528,13 +530,11 @@ def execute_llm_question(
528530
# print(f"LLM response: {output}")
529531
# print(f"Messages: {messages}")
530532
# print("type of output", type(output))
531-
command_history.add_command(command, [], output, location)
532533
return {"messages": messages, "output": output}
533534

534535

535536
def execute_llm_command(
536537
command: str,
537-
command_history: Any,
538538
model: Optional[str] = None,
539539
provider: Optional[str] = None,
540540
api_url: str = NPCSH_API_URL,
@@ -550,7 +550,7 @@ def execute_llm_command(
550550
This function executes an LLM command.
551551
Args:
552552
command (str): The command to execute.
553-
command_history (Any): The command history.
553+
554554
Keyword Args:
555555
model (Optional[str]): The model to use for executing the command.
556556
provider (Optional[str]): The provider to use for executing the command.
@@ -654,6 +654,8 @@ def execute_llm_command(
654654
655655
{context}
656656
"""
657+
messages.append({"role": "user", "content": prompt})
658+
# print(messages, stream)
657659
if stream:
658660
response = get_stream(
659661
messages,
@@ -663,6 +665,7 @@ def execute_llm_command(
663665
api_key=api_key,
664666
npc=npc,
665667
)
668+
return response
666669

667670
else:
668671
response = get_llm_response(
@@ -677,7 +680,6 @@ def execute_llm_command(
677680
output = response.get("response", "")
678681

679682
# render_markdown(output)
680-
command_history.add_command(command, subcommands, output, location)
681683

682684
return {"messages": messages, "output": output}
683685
except subprocess.CalledProcessError as e:
@@ -734,7 +736,6 @@ def execute_llm_command(
734736

735737
attempt += 1
736738

737-
command_history.add_command(command, subcommands, "Execution failed", location)
738739
return {
739740
"messages": messages,
740741
"output": "Max attempts reached. Unable to execute the command successfully.",
@@ -743,9 +744,10 @@ def execute_llm_command(
743744

744745
def check_llm_command(
745746
command: str,
746-
command_history: Any,
747747
model: str = NPCSH_CHAT_MODEL,
748748
provider: str = NPCSH_CHAT_PROVIDER,
749+
reasoning_model: str = NPCSH_REASONING_MODEL,
750+
reasoning_provider: str = NPCSH_REASONING_PROVIDER,
749751
api_url: str = NPCSH_API_URL,
750752
api_key: str = None,
751753
npc: Any = None,
@@ -760,7 +762,6 @@ def check_llm_command(
760762
This function checks an LLM command.
761763
Args:
762764
command (str): The command to check.
763-
command_history (Any): The command history.
764765
Keyword Args:
765766
model (str): The model to use for checking the command.
766767
provider (str): The provider to use for checking the command.
@@ -771,6 +772,21 @@ def check_llm_command(
771772
Any: The result of checking the LLM command.
772773
"""
773774

775+
ENTER_REASONING_FLOW = False
776+
if NPCSH_DEFAULT_MODE == "reasoning":
777+
ENTER_REASONING_FLOW = True
778+
if model in available_reasoning_models:
779+
print(
780+
"""
781+
Model provided is a reasoning model, defaulting to non reasoning model for
782+
ReAct choices then will enter reasoning flow
783+
"""
784+
)
785+
reasoning_model = model
786+
reasoning_provider = provider
787+
788+
model = NPCSH_CHAT_MODEL
789+
provider = NPCSH_CHAT_PROVIDER
774790
if messages is None:
775791
messages = []
776792

@@ -799,6 +815,7 @@ def check_llm_command(
799815
800816
Available tools:
801817
"""
818+
802819
if npc.all_tools_dict is None:
803820
prompt += "No tools available."
804821
else:
@@ -899,14 +916,15 @@ def check_llm_command(
899916

900917
# Proceed according to the action specified
901918
action = response_content_parsed.get("action")
902-
919+
explanation = response_content["explanation"]
903920
# Include the user's command in the conversation messages
921+
print(f"action chosen: {action}")
922+
print(f"explanation given: {explanation}")
904923

905924
if action == "execute_command":
906925
# Pass messages to execute_llm_command
907926
result = execute_llm_command(
908927
command,
909-
command_history,
910928
model=model,
911929
provider=provider,
912930
api_url=api_url,
@@ -916,6 +934,8 @@ def check_llm_command(
916934
retrieved_docs=retrieved_docs,
917935
stream=stream,
918936
)
937+
if stream:
938+
return result
919939

920940
output = result.get("output", "")
921941
messages = result.get("messages", messages)
@@ -924,10 +944,10 @@ def check_llm_command(
924944
elif action == "invoke_tool":
925945
tool_name = response_content_parsed.get("tool_name")
926946
# print(npc)
947+
927948
result = handle_tool_call(
928949
command,
929950
tool_name,
930-
command_history,
931951
model=model,
932952
provider=provider,
933953
api_url=api_url,
@@ -944,19 +964,26 @@ def check_llm_command(
944964
return {"messages": messages, "output": output}
945965

946966
elif action == "answer_question":
947-
result = execute_llm_question(
948-
command,
949-
command_history,
950-
model=model,
951-
provider=provider,
952-
api_url=api_url,
953-
api_key=api_key,
954-
messages=messages,
955-
npc=npc,
956-
retrieved_docs=retrieved_docs,
957-
stream=stream,
958-
images=images,
959-
)
967+
968+
if ENTER_REASONING_FLOW:
969+
print("entering reasoning flow")
970+
result = enter_reasoning_human_in_the_loop(
971+
messages, reasoning_model, reasoning_provider
972+
)
973+
else:
974+
result = execute_llm_question(
975+
command,
976+
model=model,
977+
provider=provider,
978+
api_url=api_url,
979+
api_key=api_key,
980+
messages=messages,
981+
npc=npc,
982+
retrieved_docs=retrieved_docs,
983+
stream=stream,
984+
images=images,
985+
)
986+
960987
if stream:
961988
return result
962989
messages = result.get("messages", messages)
@@ -969,7 +996,6 @@ def check_llm_command(
969996
return npc.handle_agent_pass(
970997
npc_to_pass,
971998
command,
972-
command_history,
973999
messages=messages,
9741000
retrieved_docs=retrieved_docs,
9751001
n_docs=n_docs,
@@ -1005,7 +1031,6 @@ def check_llm_command(
10051031

10061032
return check_llm_command(
10071033
command + " \n \n \n extra context: " + request_input,
1008-
command_history,
10091034
model=model,
10101035
provider=provider,
10111036
api_url=api_url,
@@ -1014,6 +1039,7 @@ def check_llm_command(
10141039
messages=messages,
10151040
retrieved_docs=retrieved_docs,
10161041
n_docs=n_docs,
1042+
stream=stream,
10171043
)
10181044

10191045
elif action == "execute_sequence":
@@ -1024,7 +1050,6 @@ def check_llm_command(
10241050
result = handle_tool_call(
10251051
command,
10261052
tool_name,
1027-
command_history,
10281053
model=model,
10291054
provider=provider,
10301055
api_url=api_url,
@@ -1049,7 +1074,6 @@ def check_llm_command(
10491074
def handle_tool_call(
10501075
command: str,
10511076
tool_name: str,
1052-
command_history: Any,
10531077
model: str = NPCSH_CHAT_MODEL,
10541078
provider: str = NPCSH_CHAT_PROVIDER,
10551079
api_url: str = NPCSH_API_URL,
@@ -1068,7 +1092,6 @@ def handle_tool_call(
10681092
Args:
10691093
command (str): The command.
10701094
tool_name (str): The tool name.
1071-
command_history (Any): The command history.
10721095
Keyword Args:
10731096
model (str): The model to use for handling the tool call.
10741097
provider (str): The provider to use for handling the tool call.
@@ -1081,7 +1104,6 @@ def handle_tool_call(
10811104
the tool call.
10821105
10831106
"""
1084-
print(f"handle_tool_call invoked with tool_name: {tool_name}")
10851107
# print(npc)
10861108
if not npc or not npc.all_tools_dict:
10871109
print("not available")
@@ -1168,7 +1190,6 @@ def handle_tool_call(
11681190
return handle_tool_call(
11691191
command,
11701192
tool_name,
1171-
command_history,
11721193
model=model,
11731194
provider=provider,
11741195
messages=messages,
@@ -1213,7 +1234,6 @@ def handle_tool_call(
12131234

12141235
def execute_data_operations(
12151236
query: str,
1216-
command_history: Any,
12171237
dataframes: Dict[str, pd.DataFrame],
12181238
npc: Any = None,
12191239
db_path: str = "~/npcsh_history.db",
@@ -1223,7 +1243,7 @@ def execute_data_operations(
12231243
This function executes data operations.
12241244
Args:
12251245
query (str): The query to execute.
1226-
command_history (Any): The command history.
1246+
12271247
dataframes (Dict[str, pd.DataFrame]): The dictionary of dataframes.
12281248
Keyword Args:
12291249
npc (Any): The NPC object.

0 commit comments

Comments
 (0)