
Merge pull request #99 from crestalnetwork/improve/chat-log
Improve: chat log
taiyangc authored Jan 27, 2025
2 parents 7578480 + e3c47eb commit db30043
Showing 12 changed files with 245 additions and 92 deletions.
1 change: 1 addition & 0 deletions abstracts/graph.py
@@ -14,6 +14,7 @@ class AgentState(TypedDict):
"""The state of the agent."""

messages: Annotated[Sequence[BaseMessage], add_messages]
need_clear: bool
is_last_step: IsLastStep
remaining_steps: RemainingSteps
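A minimal sketch (not part of this diff; the node name is hypothetical) of how a graph node could set the new need_clear flag so the memory manager wipes conversation history on its next pass:

from langchain_core.messages import AIMessage

from abstracts.graph import AgentState


def reset_history_node(state: AgentState) -> AgentState:
    # Hypothetical node: returning need_clear=True asks the memory manager
    # to mark every stored message for removal on the next run.
    return {
        "need_clear": True,
        "messages": [AIMessage(content="Starting over with a fresh conversation.")],
    }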

75 changes: 45 additions & 30 deletions app/core/engine.py
@@ -26,7 +26,6 @@
from langgraph.graph.graph import CompiledGraph
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.exc import NoResultFound
from sqlmodel import select

from abstracts.engine import AgentMessageInput
from abstracts.graph import AgentState
@@ -121,15 +120,25 @@ def initialize_agent(aid):
raise HTTPException(status_code=500, detail=str(e))

# ==== Initialize LLM.
input_token_limit = 120000
# TODO: model name whitelist
if agent.model.startswith("deepseek"):
llm = ChatOpenAI(
model_name=agent.model,
openai_api_key=config.deepseek_api_key,
openai_api_base="https://api.deepseek.com",
presence_penalty=1,
streaming=False,
timeout=90,
)
input_token_limit = 60000
else:
llm = ChatOpenAI(model_name=agent.model, openai_api_key=config.openai_api_key)
llm = ChatOpenAI(
model_name=agent.model,
openai_api_key=config.openai_api_key,
timeout=60,
presence_penalty=1,
)

# ==== Store buffered conversation history in memory.
memory = PostgresSaver(get_coon())
@@ -271,10 +280,13 @@ def initialize_agent(aid):
prompt_array.insert(0, ("system", twitter_prompt))
else:
prompt_array.append(("system", twitter_prompt))
if agent.prompt_append and not agent.model.startswith("deepseek"):
if agent.prompt_append:
# Escape any curly braces in prompt_append
escaped_append = agent.prompt_append.replace("{", "{{").replace("}", "}}")
prompt_array.append(("system", escaped_append))
if agent.model.startswith("deepseek"):
prompt_array.insert(0, ("system", escaped_append))
else:
prompt_array.append(("system", escaped_append))
prompt_temp = ChatPromptTemplate.from_messages(prompt_array)

def formatted_prompt(state: AgentState):
@@ -292,6 +304,7 @@ def formatted_prompt(state: AgentState):
checkpointer=memory,
state_modifier=formatted_prompt,
debug=config.debug_checkpoint,
input_token_limit=input_token_limit,
)


@@ -355,31 +368,33 @@ def execute_agent(
]
)
# debug prompt
if debug:
# get the agent from the database
with get_session() as db:
try:
agent: Agent = db.exec(select(Agent).filter(Agent.id == aid)).one()
except NoResultFound:
# Handle the case where the user is not found
raise HTTPException(status_code=404, detail="Agent not found")
except SQLAlchemyError as e:
# Handle other SQLAlchemy-related errors
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
try:
resp_debug_append = "\n===================\n\n[ system ]\n"
resp_debug_append += agent_prompt(agent)
snap = executor.get_state(stream_config)
if snap.values and "messages" in snap.values:
for msg in snap.values["messages"]:
resp_debug_append += f"[ {msg.type} ]\n{msg.content}\n\n"
if agent.prompt_append:
resp_debug_append += "[ system ]\n"
resp_debug_append += agent.prompt_append
except Exception as e:
logger.error(e)
resp_debug_append = ""
# if debug:
# # get the agent from the database
# with get_session() as db:
# try:
# agent: Agent = db.exec(select(Agent).filter(Agent.id == aid)).one()
# except NoResultFound:
# # Handle the case where the user is not found
# raise HTTPException(status_code=404, detail="Agent not found")
# except SQLAlchemyError as e:
# # Handle other SQLAlchemy-related errors
# logger.error(e)
# raise HTTPException(status_code=500, detail=str(e))
# try:
# resp_debug_append = "\n===================\n\n[ system ]\n"
# resp_debug_append += agent_prompt(agent)
# snap = executor.get_state(stream_config)
# if snap.values and "messages" in snap.values:
# for msg in snap.values["messages"]:
# resp_debug_append += f"[ {msg.type} ]\n{str(msg.content)}\n\n"
# if agent.prompt_append:
# resp_debug_append += "[ system ]\n"
# resp_debug_append += agent.prompt_append
# except Exception as e:
# logger.error(
# "failed to get debug prompt: " + str(e), exc_info=True, stack_info=True
# )
# resp_debug_append = ""
# run
for chunk in executor.stream(
{"messages": [HumanMessage(content=content)]}, stream_config
@@ -407,7 +422,7 @@ def execute_agent(
total_time = time.perf_counter() - start
resp_debug.append(f"Total time cost: {total_time:.3f} seconds")
if debug:
resp_debug.append(resp_debug_append)
# resp_debug.append(resp_debug_append)
return resp_debug
else:
return resp
153 changes: 118 additions & 35 deletions app/core/graph.py
@@ -3,6 +3,7 @@
import logging
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast

import tiktoken
from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.messages import (
AIMessage,
@@ -143,6 +144,51 @@ def _validate_chat_history(
raise ValueError(error_message)


# Cache for tiktoken encoders
_TIKTOKEN_CACHE = {}


def _get_encoder(model_name: str = "gpt-4"):
"""Get cached tiktoken encoder."""
if model_name not in _TIKTOKEN_CACHE:
try:
_TIKTOKEN_CACHE[model_name] = tiktoken.encoding_for_model(model_name)
except KeyError:
_TIKTOKEN_CACHE[model_name] = tiktoken.get_encoding("cl100k_base")
return _TIKTOKEN_CACHE[model_name]


def _count_tokens(messages: Sequence[BaseMessage], model_name: str = "gpt-4") -> int:
"""Count the number of tokens in a list of messages."""
encoding = _get_encoder(model_name)

num_tokens = 0
for message in messages:
# Every message follows <im_start>{role/name}\n{content}<im_end>\n
num_tokens += 4

# Count tokens for basic message attributes
msg_dict = message.model_dump()
for key in ["content", "name", "function_call", "role"]:
value = msg_dict.get(key)
if value:
num_tokens += len(encoding.encode(str(value)))

# Count tokens for tool calls more efficiently
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
# Only encode essential parts of tool_call
if isinstance(tool_call, dict):
for key in ["name", "arguments"]:
if key in tool_call:
num_tokens += len(encoding.encode(str(tool_call[key])))
else:
# Handle tool_call object if it's not a dict
num_tokens += len(encoding.encode(str(tool_call)))

return num_tokens
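A quick usage sketch of the token counter (illustrative only; assumes tiktoken is installed, and unknown model names fall back to the cl100k_base encoding):

from langchain_core.messages import AIMessage, HumanMessage

history = [
    HumanMessage(content="Summarize the latest release notes."),
    AIMessage(content="Sure, the release improves the chat log."),
]
# Roughly 4 tokens of per-message overhead plus the encoded content.
print(_count_tokens(history, model_name="gpt-4"))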


def create_agent(
model: LanguageModelLike,
tools: Union[ToolExecutor, Sequence[BaseTool], ToolNode],
@@ -154,6 +200,7 @@
store: Optional[BaseStore] = None,
interrupt_before: Optional[list[str]] = None,
interrupt_after: Optional[list[str]] = None,
input_token_limit: int = 120000,
debug: bool = False,
) -> CompiledGraph:
"""Creates a graph that works with a chat model that utilizes tool calling.
@@ -260,43 +307,40 @@ class Agent,Tools otherClass

def default_memory_manager(state: AgentState) -> AgentState:
messages = state["messages"]
# logger.debug("Before memory manager: %s", messages)

# Merge adjacent HumanMessages
i = 0
while i < len(messages) - 1:
if isinstance(messages[i], HumanMessage) and isinstance(
messages[i + 1], HumanMessage
):
# Handle different content types
content1 = messages[i].content
content2 = messages[i + 1].content

# Convert to list if string
if isinstance(content1, str):
content1 = [content1]
if isinstance(content2, str):
content2 = [content2]
# If need_clear is True, mark all messages for removal
if "need_clear" in state and state["need_clear"]:
for index in range(len(messages)):
messages[index] = RemoveMessage(id=messages[index].id)
return state

# Merge the contents
messages[i].content = content1 + content2
# Count total tokens
total_tokens = _count_tokens(messages)
token_limit = (
state.get("input_token_limit", 120000) // 2
) # Half of the input token limit

# If over token limit, remove messages from front
if total_tokens > token_limit:
must_delete = 0
current_tokens = total_tokens
temp_messages = messages.copy()

# Calculate how many messages to delete
while current_tokens > token_limit and must_delete < len(temp_messages):
current_tokens -= _count_tokens([temp_messages[must_delete]])
must_delete += 1

# Ensure first remaining message is HumanMessage
while must_delete < len(messages) and not isinstance(
messages[must_delete], HumanMessage
):
must_delete += 1

# Remove the second message
messages.pop(i + 1)
else:
i += 1
# Mark messages for removal
for index in range(must_delete):
messages[index] = RemoveMessage(id=messages[index].id)

if len(messages) <= 100:
return state
must_delete = len(messages) - 100
for index, message in enumerate(messages):
if index < must_delete:
messages[index] = RemoveMessage(id=message.id)
elif not isinstance(message, HumanMessage):
messages[index] = RemoveMessage(id=message.id)
else:
break
# logger.debug("After memory manager: %s", messages)
return state
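The manager relies on langgraph's add_messages reducer dropping any message whose id is wrapped in a RemoveMessage. A standalone sketch of that mechanism (illustrative, not part of this diff):

from langchain_core.messages import AIMessage, HumanMessage, RemoveMessage
from langgraph.graph.message import add_messages

history = [
    HumanMessage(content="hi", id="1"),
    AIMessage(content="hello!", id="2"),
    HumanMessage(content="what changed in this PR?", id="3"),
]
# Marking the two oldest messages the way default_memory_manager does
# leaves only the message with id="3" after the reducer runs.
trimmed = add_messages(history, [RemoveMessage(id="1"), RemoveMessage(id="2")])
print([m.content for m in trimmed])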

if memory_manager is None:
@@ -305,7 +349,33 @@ def default_memory_manager(state: AgentState) -> AgentState:
# Define the function that calls the model
def call_model(state: AgentState, config: RunnableConfig) -> AgentState:
_validate_chat_history(state["messages"])
response = model_runnable.invoke(state, config)

try:
logger.debug("Starting model invocation...")
response = model_runnable.invoke(state, config)
logger.debug(f"Model invocation completed. Response type: {type(response)}")

# Log response details
if isinstance(response, AIMessage):
has_tool_calls = bool(response.tool_calls)
logger.debug(f"Response is AIMessage. Has tool calls: {has_tool_calls}")
if has_tool_calls:
logger.debug(f"Number of tool calls: {len(response.tool_calls)}")
else:
logger.debug(f"Response is not AIMessage: {type(response)}")

except Exception as e:
logger.error(f"Error in call model: {e}", exc_info=True)
# Clean message history on error
return {
"need_clear": True,
"messages": [
AIMessage(
content=f"Sorry, something went wrong. {e}",
)
],
}

has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
all_tools_return_direct = (
all(call["name"] in should_return_direct for call in response.tool_calls)
@@ -338,11 +408,24 @@ def call_model(state: AgentState, config: RunnableConfig) -> AgentState:
]
}
# We return a list, because this will get added to the existing list
logger.debug(f"Response: {response}")
return {"messages": [response]}

async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
_validate_chat_history(state["messages"])
response = await model_runnable.ainvoke(state, config)
try:
response = await model_runnable.ainvoke(state, config)
except Exception as e:
logger.error(f"Error in async call model: {e}")
# Clean message history on error
return {
"messages": [
AIMessage(
content=f"Sorry, something went wrong. {e}",
)
],
"need_clear": True,
}
has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
all_tools_return_direct = (
all(call["name"] in should_return_direct for call in response.tool_calls)
14 changes: 14 additions & 0 deletions skills/twitter/base.py
@@ -21,6 +21,20 @@ class TwitterBaseTool(IntentKitSkill):
)
store: SkillStoreABC = Field(description="The skill store for persisting data")

def _get_error_with_username(self, error_msg: str) -> str:
"""Get error message with username if available.
Args:
error_msg: The original error message.
Returns:
Error message with username if available.
"""
username = self.twitter.get_username()
if username:
return f"Error for Twitter user @{username}: {error_msg}"
return error_msg

def check_rate_limit(
self, max_requests: int = 1, interval: int = 15
) -> tuple[bool, str | None]:
14 changes: 10 additions & 4 deletions skills/twitter/follow_user.py
@@ -53,14 +53,19 @@ def _run(self, user_id: str) -> TwitterFollowUserOutput:
)
if is_rate_limited:
return TwitterFollowUserOutput(
success=False, message=f"Error following user: {error}"
success=False,
message=self._get_error_with_username(
f"Error following user: {error}"
),
)

client = self.twitter.get_client()
if not client:
return TwitterFollowUserOutput(
success=False,
message="Failed to get Twitter client. Please check your authentication.",
message=self._get_error_with_username(
"Failed to get Twitter client. Please check your authentication."
),
)

# Follow the user using tweepy client
@@ -73,12 +78,13 @@ def _run(self, user_id: str) -> TwitterFollowUserOutput:
success=True, message=f"Successfully followed user {user_id}"
)
return TwitterFollowUserOutput(
success=False, message="Failed to follow user."
success=False,
message=self._get_error_with_username("Failed to follow user."),
)

except Exception as e:
return TwitterFollowUserOutput(
success=False, message=f"Error following user: {str(e)}"
success=False, message=self._get_error_with_username(str(e))
)

async def _arun(self, user_id: str) -> TwitterFollowUserOutput:
(Diffs for the remaining changed files are not shown.)
