diff --git a/abstracts/graph.py b/abstracts/graph.py index f18147b..c3b3584 100644 --- a/abstracts/graph.py +++ b/abstracts/graph.py @@ -14,6 +14,7 @@ class AgentState(TypedDict): """The state of the agent.""" messages: Annotated[Sequence[BaseMessage], add_messages] + need_clear: bool is_last_step: IsLastStep remaining_steps: RemainingSteps diff --git a/app/core/engine.py b/app/core/engine.py index 78c540e..f5ef0a3 100644 --- a/app/core/engine.py +++ b/app/core/engine.py @@ -26,7 +26,6 @@ from langgraph.graph.graph import CompiledGraph from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.exc import NoResultFound -from sqlmodel import select from abstracts.engine import AgentMessageInput from abstracts.graph import AgentState @@ -121,15 +120,25 @@ def initialize_agent(aid): raise HTTPException(status_code=500, detail=str(e)) # ==== Initialize LLM. + input_token_limit = 120000 # TODO: model name whitelist if agent.model.startswith("deepseek"): llm = ChatOpenAI( model_name=agent.model, openai_api_key=config.deepseek_api_key, openai_api_base="https://api.deepseek.com", + presence_penalty=1, + streaming=False, + timeout=90, ) + input_token_limit = 60000 else: - llm = ChatOpenAI(model_name=agent.model, openai_api_key=config.openai_api_key) + llm = ChatOpenAI( + model_name=agent.model, + openai_api_key=config.openai_api_key, + timeout=60, + presence_penalty=1, + ) # ==== Store buffered conversation history in memory. 
memory = PostgresSaver(get_coon()) @@ -271,10 +280,13 @@ def initialize_agent(aid): prompt_array.insert(0, ("system", twitter_prompt)) else: prompt_array.append(("system", twitter_prompt)) - if agent.prompt_append and not agent.model.startswith("deepseek"): + if agent.prompt_append: # Escape any curly braces in prompt_append escaped_append = agent.prompt_append.replace("{", "{{").replace("}", "}}") - prompt_array.append(("system", escaped_append)) + if agent.model.startswith("deepseek"): + prompt_array.insert(0, ("system", escaped_append)) + else: + prompt_array.append(("system", escaped_append)) prompt_temp = ChatPromptTemplate.from_messages(prompt_array) def formatted_prompt(state: AgentState): @@ -292,6 +304,7 @@ def formatted_prompt(state: AgentState): checkpointer=memory, state_modifier=formatted_prompt, debug=config.debug_checkpoint, + input_token_limit=input_token_limit, ) @@ -355,31 +368,33 @@ def execute_agent( ] ) # debug prompt - if debug: - # get the agent from the database - with get_session() as db: - try: - agent: Agent = db.exec(select(Agent).filter(Agent.id == aid)).one() - except NoResultFound: - # Handle the case where the user is not found - raise HTTPException(status_code=404, detail="Agent not found") - except SQLAlchemyError as e: - # Handle other SQLAlchemy-related errors - logger.error(e) - raise HTTPException(status_code=500, detail=str(e)) - try: - resp_debug_append = "\n===================\n\n[ system ]\n" - resp_debug_append += agent_prompt(agent) - snap = executor.get_state(stream_config) - if snap.values and "messages" in snap.values: - for msg in snap.values["messages"]: - resp_debug_append += f"[ {msg.type} ]\n{msg.content}\n\n" - if agent.prompt_append: - resp_debug_append += "[ system ]\n" - resp_debug_append += agent.prompt_append - except Exception as e: - logger.error(e) - resp_debug_append = "" + # if debug: + # # get the agent from the database + # with get_session() as db: + # try: + # agent: Agent = 
# Cache of tiktoken encoders keyed by model name, so each encoder is built once.
_TIKTOKEN_CACHE = {}


def _get_encoder(model_name: str = "gpt-4"):
    """Return the tiktoken encoder for ``model_name``, building it at most once.

    Args:
        model_name: Name of the model whose tokenizer should be used.

    Returns:
        The encoder stored in ``_TIKTOKEN_CACHE`` for ``model_name``. When
        tiktoken raises ``KeyError`` for the model name, the ``cl100k_base``
        encoding is cached under that name and returned instead.
    """
    encoder = _TIKTOKEN_CACHE.get(model_name)
    if encoder is None:
        # Only needed on a cache miss.
        import tiktoken

        try:
            encoder = tiktoken.encoding_for_model(model_name)
        except KeyError:
            encoder = tiktoken.get_encoding("cl100k_base")
        _TIKTOKEN_CACHE[model_name] = encoder
    return encoder


def _count_tokens(messages: Sequence[BaseMessage], model_name: str = "gpt-4") -> int:
    """Estimate how many tokens ``messages`` occupy for ``model_name``.

    Every message contributes a fixed overhead of 4 tokens, plus the encoded
    length of its ``content``, ``name``, ``function_call`` and ``role`` fields
    (when set), plus the encoded name and arguments of each tool call it
    carries.

    Args:
        messages: Messages to measure; each must provide ``model_dump()``.
        model_name: Model whose tokenizer is used for the estimate.

    Returns:
        The summed token estimate; ``0`` for an empty sequence.
    """
    encoding = _get_encoder(model_name)

    def encoded_len(value) -> int:
        # Values may be strings, lists or dicts; measure their string form.
        return len(encoding.encode(str(value)))

    num_tokens = 0
    for message in messages:
        # Every message follows {role/name}\n{content}\n
        num_tokens += 4

        msg_dict = message.model_dump()
        num_tokens += sum(
            encoded_len(msg_dict[key])
            for key in ("content", "name", "function_call", "role")
            if msg_dict.get(key)
        )

        for tool_call in getattr(message, "tool_calls", None) or ():
            if isinstance(tool_call, dict):
                # LangChain tool-call dicts keep their payload under "args";
                # "arguments" is still honoured for dicts that use that key.
                # Ids and other bookkeeping fields are deliberately ignored.
                num_tokens += sum(
                    encoded_len(tool_call[key])
                    for key in ("name", "args", "arguments")
                    if key in tool_call
                )
            else:
                # Non-dict tool call object: fall back to its string form.
                num_tokens += encoded_len(tool_call)

    return num_tokens
def default_memory_manager(state: AgentState) -> AgentState:
    """Trim or clear the conversation history held in ``state["messages"]``.

    Messages are dropped by replacing them in place with
    ``RemoveMessage(id=...)`` markers.

    * When ``state["need_clear"]`` is truthy, every message is marked for
      removal and the flag is reset.
    * Otherwise, when the history exceeds half of the input token limit, the
      oldest messages are marked for removal until the estimate fits, and
      removal then continues up to the next ``HumanMessage`` so the
      remaining history starts with a human turn.

    Args:
        state: Agent state; its ``messages`` list is modified in place.

    Returns:
        The same ``state`` object.
    """
    messages = state["messages"]

    if state.get("need_clear"):
        for index, message in enumerate(messages):
            messages[index] = RemoveMessage(id=message.id)
        # Reset the flag; otherwise the returned state still carries
        # need_clear=True and every later pass would clear the history again.
        # NOTE(review): this also removes a message just added for the
        # current turn — confirm that is intended.
        state["need_clear"] = False
        return state

    # Use half of the input token limit. AgentState declares no
    # "input_token_limit" key, so fall back to the limit given to
    # create_agent rather than a hard-coded 120000.
    token_limit = state.get("input_token_limit", input_token_limit) // 2

    # Tokenize each message once; the total is the sum of the parts.
    per_message = [_count_tokens([message]) for message in messages]
    current_tokens = sum(per_message)
    if current_tokens <= token_limit:
        return state

    # Drop messages from the front until the estimate fits.
    must_delete = 0
    while current_tokens > token_limit and must_delete < len(messages):
        current_tokens -= per_message[must_delete]
        must_delete += 1

    # Ensure the first remaining message is a HumanMessage.
    while must_delete < len(messages) and not isinstance(
        messages[must_delete], HumanMessage
    ):
        must_delete += 1

    for index in range(must_delete):
        messages[index] = RemoveMessage(id=messages[index].id)

    return state
def _get_error_with_username(self, error_msg: str) -> str:
    """Prefix ``error_msg`` with the Twitter username when one is available.

    The username lookup is best-effort. This helper is called while an
    error is already being reported, so a failing lookup must not replace
    or hide the original error.

    Args:
        error_msg: The original error message.

    Returns:
        ``"Error for Twitter user @<username>: <error_msg>"`` when
        ``self.twitter.get_username()`` returns a truthy value; otherwise
        ``error_msg`` unchanged (including when the lookup raises).
    """
    try:
        username = self.twitter.get_username()
    except Exception:
        # Deliberately broad: decorating the message is optional, the
        # original error is what the caller needs to see.
        import logging

        logging.getLogger(__name__).debug(
            "Could not resolve Twitter username for error message",
            exc_info=True,
        )
        username = None

    if not username:
        return error_msg
    return f"Error for Twitter user @{username}: {error_msg}"
Please check your authentication.", + error=self._get_error_with_username( + "Failed to get Twitter client. Please check your authentication." + ), ) user_id = self.twitter.get_id() if not user_id: return TwitterGetMentionsOutput( - mentions=[], error="Failed to get Twitter user ID." + mentions=[], + error=self._get_error_with_username( + "Failed to get Twitter user ID." + ), ) mentions = client.get_users_mentions( @@ -121,7 +126,9 @@ def _run(self) -> TwitterGetMentionsOutput: except Exception as e: logger.error("Error getting mentions: %s", str(e)) - return TwitterGetMentionsOutput(mentions=[], error=str(e)) + return TwitterGetMentionsOutput( + mentions=[], error=self._get_error_with_username(str(e)) + ) async def _arun(self) -> TwitterGetMentionsOutput: """Async implementation of the tool. diff --git a/skills/twitter/get_timeline.py b/skills/twitter/get_timeline.py index 9373781..ea7a6cd 100644 --- a/skills/twitter/get_timeline.py +++ b/skills/twitter/get_timeline.py @@ -72,7 +72,9 @@ def _run(self, max_results: int = 10) -> TwitterGetTimelineOutput: if not client: return TwitterGetTimelineOutput( tweets=[], - error="Failed to get Twitter client. Please check your authentication.", + error=self._get_error_with_username( + "Failed to get Twitter client. Please check your authentication." + ), ) timeline = client.get_home_timeline( @@ -120,7 +122,9 @@ def _run(self, max_results: int = 10) -> TwitterGetTimelineOutput: except Exception as e: logger.error("Error getting timeline: %s", str(e)) - return TwitterGetTimelineOutput(tweets=[], error=str(e)) + return TwitterGetTimelineOutput( + tweets=[], error=self._get_error_with_username(str(e)) + ) async def _arun(self) -> TwitterGetTimelineOutput: """Async implementation of the tool. 
diff --git a/skills/twitter/like_tweet.py b/skills/twitter/like_tweet.py index ceb8e4f..e6117a0 100644 --- a/skills/twitter/like_tweet.py +++ b/skills/twitter/like_tweet.py @@ -60,7 +60,9 @@ def _run(self, tweet_id: str) -> TwitterLikeTweetOutput: if not client: return TwitterLikeTweetOutput( success=False, - message="Failed to get Twitter client. Please check your authentication.", + message=self._get_error_with_username( + "Failed to get Twitter client. Please check your authentication." + ), ) # Like the tweet using tweepy client @@ -71,12 +73,13 @@ def _run(self, tweet_id: str) -> TwitterLikeTweetOutput: success=True, message=f"Successfully liked tweet {tweet_id}" ) return TwitterLikeTweetOutput( - success=False, message="Failed to like tweet." + success=False, + message=self._get_error_with_username("Failed to like tweet."), ) except Exception as e: return TwitterLikeTweetOutput( - success=False, message=f"Error liking tweet: {str(e)}" + success=False, message=self._get_error_with_username(str(e)) ) async def _arun(self, tweet_id: str) -> TwitterLikeTweetOutput: diff --git a/skills/twitter/post_tweet.py b/skills/twitter/post_tweet.py index 56f9330..d591bf8 100644 --- a/skills/twitter/post_tweet.py +++ b/skills/twitter/post_tweet.py @@ -51,7 +51,9 @@ def _run(self, text: str) -> str: client = self.twitter.get_client() if not client: - return "Failed to get Twitter client. Please check your authentication." + return self._get_error_with_username( + "Failed to get Twitter client. Please check your authentication." + ) # Post tweet using tweepy client response = client.create_tweet(text=text, user_auth=self.twitter.use_key) @@ -59,10 +61,10 @@ def _run(self, text: str) -> str: if "data" in response and "id" in response["data"]: tweet_id = response["data"]["id"] return f"Tweet posted successfully! Tweet ID: {tweet_id}" - return "Failed to post tweet." 
+ return self._get_error_with_username("Failed to post tweet.") except Exception as e: - return f"Error posting tweet: {str(e)}" + return self._get_error_with_username(str(e)) async def _arun(self, text: str) -> str: """Async implementation of the tool. diff --git a/skills/twitter/reply_tweet.py b/skills/twitter/reply_tweet.py index edfb49a..784ae90 100644 --- a/skills/twitter/reply_tweet.py +++ b/skills/twitter/reply_tweet.py @@ -47,11 +47,15 @@ def _run(self, tweet_id: str, text: str) -> str: max_requests=48, interval=1440 ) if is_rate_limited: - return f"Error replying to tweet: {error}" + return self._get_error_with_username( + f"Error replying to tweet: {error}" + ) client = self.twitter.get_client() if not client: - return "Failed to get Twitter client. Please check your authentication." + return self._get_error_with_username( + "Failed to get Twitter client. Please check your authentication." + ) # Post reply tweet using tweepy client response = client.create_tweet( @@ -61,10 +65,10 @@ def _run(self, tweet_id: str, text: str) -> str: if "data" in response and "id" in response["data"]: reply_id = response["data"]["id"] return f"Reply posted successfully! Reply Tweet ID: {reply_id}" - return "Failed to post reply tweet." + return self._get_error_with_username("Failed to post reply tweet.") except Exception as e: - return f"Error posting reply tweet: {str(e)}" + return self._get_error_with_username(f"Error posting reply tweet: {str(e)}") async def _arun(self, tweet_id: str, text: str) -> str: """Async implementation of the tool. diff --git a/skills/twitter/retweet.py b/skills/twitter/retweet.py index 06e19f8..f851dfa 100644 --- a/skills/twitter/retweet.py +++ b/skills/twitter/retweet.py @@ -60,7 +60,9 @@ def _run(self, tweet_id: str) -> TwitterRetweetOutput: if not client: return TwitterRetweetOutput( success=False, - message="Failed to get Twitter client. 
Please check your authentication.", + message=self._get_error_with_username( + "Failed to get Twitter client. Please check your authentication." + ), ) # Get authenticated user's ID @@ -82,12 +84,13 @@ def _run(self, tweet_id: str) -> TwitterRetweetOutput: success=True, message=f"Successfully retweeted tweet {tweet_id}" ) return TwitterRetweetOutput( - success=False, message="Failed to retweet tweet." + success=False, + message=self._get_error_with_username("Failed to retweet."), ) except Exception as e: return TwitterRetweetOutput( - success=False, message=f"Error retweeting tweet: {str(e)}" + success=False, message=self._get_error_with_username(str(e)) ) async def _arun(self, tweet_id: str) -> TwitterRetweetOutput: diff --git a/skills/twitter/search_tweets.py b/skills/twitter/search_tweets.py index 25385b4..52881b8 100644 --- a/skills/twitter/search_tweets.py +++ b/skills/twitter/search_tweets.py @@ -59,13 +59,17 @@ def _run( max_requests=3, interval=15 ) if is_rate_limited: - return TwitterSearchTweetsOutput(tweets=[], error=error) + return TwitterSearchTweetsOutput( + tweets=[], error=self._get_error_with_username(error) + ) client = self.twitter.get_client() if not client: return TwitterSearchTweetsOutput( tweets=[], - error="Failed to get Twitter client. Please check your authentication.", + error=self._get_error_with_username( + "Failed to get Twitter client. Please check your authentication." 
+ ), ) # Get since_id from store to avoid duplicate results @@ -104,7 +108,11 @@ def _run( try: result = self.process_tweets_response(tweets) except Exception as e: - logger.error("Error processing search results: %s", str(e)) + logger.error( + self._get_error_with_username( + f"Error processing search results: {e}" + ) + ) raise # Update the since_id in store for the next request @@ -115,8 +123,11 @@ def _run( return TwitterSearchTweetsOutput(tweets=result) except Exception as e: - logger.error("Error searching tweets: %s", str(e)) - return TwitterSearchTweetsOutput(tweets=[], error=str(e)) + logger.error(self._get_error_with_username(f"Error searching tweets: {e}")) + return TwitterSearchTweetsOutput( + tweets=[], + error=self._get_error_with_username(f"Error searching tweets: {e}"), + ) async def _arun(self, query: str) -> TwitterSearchTweetsOutput: """Async implementation of the tool.