Commit
fix: adjust parameters to avoid deepseek bug
hyacinthus committed Jan 26, 2025
1 parent 2a7ad7c commit 65f4605
Showing 2 changed files with 32 additions and 20 deletions.
10 changes: 9 additions & 1 deletion app/core/engine.py
@@ -128,10 +128,18 @@ def initialize_agent(aid):
model_name=agent.model,
openai_api_key=config.deepseek_api_key,
openai_api_base="https://api.deepseek.com",
presence_penalty=1,
streaming=False,
timeout=90,
)
input_token_limit = 60000
else:
- llm = ChatOpenAI(model_name=agent.model, openai_api_key=config.openai_api_key)
+ llm = ChatOpenAI(
+     model_name=agent.model,
+     openai_api_key=config.openai_api_key,
+     timeout=60,
+     presence_penalty=1,
+ )

# ==== Store buffered conversation history in memory.
memory = PostgresSaver(get_coon())
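For readers skimming the hunk above: the DeepSeek branch now pins `presence_penalty`, disables streaming, lengthens the timeout, and lowers the input token budget, while the OpenAI branch gains an explicit timeout and the same `presence_penalty`. A minimal sketch of the resulting logic (the `build_llm` wrapper, the `startswith` test, and the 120000 default are illustrative assumptions; `agent` and `config` come from the surrounding `initialize_agent`):

```python
from langchain_openai import ChatOpenAI  # assumed import path


def build_llm(agent, config):
    """Hypothetical wrapper around the branch patched above."""
    if agent.model.startswith("deepseek"):  # assumption: how the DeepSeek branch is selected
        llm = ChatOpenAI(
            model_name=agent.model,
            openai_api_key=config.deepseek_api_key,
            openai_api_base="https://api.deepseek.com",
            presence_penalty=1,  # penalizes repetition; part of the bug workaround
            streaming=False,     # non-streaming responses, also part of the workaround
            timeout=90,
        )
        input_token_limit = 60000  # tighter budget than the module default
    else:
        llm = ChatOpenAI(
            model_name=agent.model,
            openai_api_key=config.openai_api_key,
            timeout=60,
            presence_penalty=1,
        )
        input_token_limit = 120000  # assumption: mirrors _limit_tokens' default in graph.py
    return llm, input_token_limit
```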
42 changes: 23 additions & 19 deletions app/core/graph.py
@@ -147,6 +147,7 @@ def _validate_chat_history(
# Cache for tiktoken encoders
_TIKTOKEN_CACHE = {}


def _get_encoder(model_name: str = "gpt-4"):
"""Get cached tiktoken encoder."""
if model_name not in _TIKTOKEN_CACHE:
@@ -156,22 +157,23 @@ def _get_encoder(model_name: str = "gpt-4"):
_TIKTOKEN_CACHE[model_name] = tiktoken.get_encoding("cl100k_base")
return _TIKTOKEN_CACHE[model_name]
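The cache exists because constructing a tiktoken encoder is not free (it loads BPE ranking data), so memoizing per model name makes repeated token counts cheap. The body of the lookup is collapsed in the diff; a self-contained reconstruction assuming the conventional `encoding_for_model` pattern, which the visible `cl100k_base` fallback line matches:

```python
import tiktoken

_TIKTOKEN_CACHE: dict[str, tiktoken.Encoding] = {}


def get_encoder(model_name: str = "gpt-4") -> tiktoken.Encoding:
    """Memoized tiktoken encoder lookup (reconstruction of the collapsed body)."""
    if model_name not in _TIKTOKEN_CACHE:
        try:
            _TIKTOKEN_CACHE[model_name] = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model names (e.g. "deepseek-chat") fall back to cl100k_base.
            _TIKTOKEN_CACHE[model_name] = tiktoken.get_encoding("cl100k_base")
    return _TIKTOKEN_CACHE[model_name]
```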


def _count_tokens(messages: Sequence[BaseMessage], model_name: str = "gpt-4") -> int:
"""Count the number of tokens in a list of messages."""
encoding = _get_encoder(model_name)

num_tokens = 0
for message in messages:
# Every message follows <im_start>{role/name}\n{content}<im_end>\n
num_tokens += 4

# Count tokens for basic message attributes
- msg_dict = message.dict()
+ msg_dict = message.model_dump()
for key in ["content", "name", "function_call", "role"]:
value = msg_dict.get(key)
if value:
num_tokens += len(encoding.encode(str(value)))

# Count tokens for tool calls more efficiently
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
@@ -183,17 +185,18 @@ def _count_tokens(messages: Sequence[BaseMessage], model_name: str = "gpt-4") ->
else:
# Handle tool_call object if it's not a dict
num_tokens += len(encoding.encode(str(tool_call)))

return num_tokens
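The `.dict()` → `.model_dump()` swap in the loop above tracks Pydantic v2, where `.dict()` is deprecated; LangChain message classes are Pydantic models, so the change is mechanical. A minimal check (assuming langchain-core on Pydantic v2):

```python
from langchain_core.messages import HumanMessage

msg = HumanMessage(content="hello")
payload = msg.model_dump()         # Pydantic v2 spelling; .dict() still works but warns
assert payload["content"] == "hello"
assert payload["type"] == "human"  # LangChain stores the role under "type"
```

Note that LangChain dumps the role under a `type` key, so the counter's `"role"` lookup will typically find nothing for these messages; the other keys still contribute.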


def _limit_tokens(
messages: list[BaseMessage], max_tokens: int = 120000
) -> list[BaseMessage]:
"""Limit the total number of tokens in messages by removing old messages.
Also merges adjacent HumanMessages as some models don't allow consecutive user messages."""
original_count = len(messages)
logger.debug(f"Starting token limiting. Original message count: {original_count}")

# First merge adjacent HumanMessages
i = 0
merge_count = 0
@@ -219,16 +222,16 @@
merge_count += 1
else:
i += 1

if merge_count > 0:
logger.debug(f"Merged {merge_count} adjacent human messages")

token_count = _count_tokens(messages)
logger.debug(f"Current token count: {token_count}, max allowed: {max_tokens}")

if token_count <= max_tokens:
return messages

# Modify messages in-place instead of creating a copy
i = 1
removed_count = 0
@@ -238,16 +241,17 @@
removed_count += 1
else:
i += 1

final_token_count = _count_tokens(messages)
- logger.debug(
+ logger.info(
f"Token limiting complete. Removed {removed_count} messages. "
f"Final message count: {len(messages)}, "
f"Final token count: {final_token_count}"
)

return messages
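A usage sketch of the trimming behavior (the messages and the artificially small cap are hypothetical; `_limit_tokens` is the function above):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

history = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="first question"),
    HumanMessage(content="a follow-up sent separately"),  # merged into the previous message
    AIMessage(content="a long answer " * 300),
    HumanMessage(content="latest question"),
]

trimmed = _limit_tokens(history, max_tokens=200)
# The removal loop starts at index 1, so history[0] (usually the system prompt)
# survives; older messages after it are dropped until the total fits under the cap.
```

The list is modified in place and also returned, per the in-place comment in the hunk above.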


def create_agent(
model: LanguageModelLike,
tools: Union[ToolExecutor, Sequence[BaseTool], ToolNode],
@@ -387,29 +391,29 @@ def default_memory_manager(state: AgentState) -> AgentState:
# Define the function that calls the model
def call_model(state: AgentState, config: RunnableConfig) -> AgentState:
_validate_chat_history(state["messages"])

# Log message state before token limiting
msg_count = len(state["messages"])
token_count = _count_tokens(state["messages"])
logger.debug(
f"Before token limiting - Messages: {msg_count}, Tokens: {token_count}"
)

# Limit tokens before calling model
state["messages"] = _limit_tokens(state["messages"], input_token_limit)

# Log state after token limiting
msg_count = len(state["messages"])
token_count = _count_tokens(state["messages"])
logger.debug(
f"After token limiting - Messages: {msg_count}, Tokens: {token_count}"
)

try:
logger.debug("Starting model invocation...")
response = model_runnable.invoke(state, config)
logger.debug(f"Model invocation completed. Response type: {type(response)}")

# Log response details
if isinstance(response, AIMessage):
has_tool_calls = bool(response.tool_calls)
@@ -418,11 +422,11 @@ def call_model(state: AgentState, config: RunnableConfig) -> AgentState:
logger.debug(f"Number of tool calls: {len(response.tool_calls)}")
else:
logger.debug(f"Response is not AIMessage: {type(response)}")

except Exception as e:
logger.error(f"Error in call model: {e}", exc_info=True)
raise e

has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
all_tools_return_direct = (
all(call["name"] in should_return_direct for call in response.tool_calls)
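The last visible lines compute whether every requested tool is in `should_return_direct`; when true, the graph presumably routes the tool output straight back instead of making another model call (the routing itself is collapsed below the hunk). An illustrative evaluation of that gate (tool names and call structure are hypothetical):

```python
# Hypothetical inputs for the gate at the end of call_model.
should_return_direct = {"fetch_weather"}  # tools whose output is returned as-is
tool_calls = [
    {"name": "fetch_weather", "args": {"city": "Paris"}, "id": "call_1"},
]

all_tools_return_direct = all(call["name"] in should_return_direct for call in tool_calls)
print(all_tools_return_direct)  # True: every call is direct-return, so no extra model hop
```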
