From 65f4605c04b871374214e36adcedc54750accd83 Mon Sep 17 00:00:00 2001
From: Muninn
Date: Sun, 26 Jan 2025 23:40:20 +0800
Subject: [PATCH] fix: adjust parameters to avoid deepseek bug

---
 app/core/engine.py | 10 +++++++++-
 app/core/graph.py  | 42 +++++++++++++++++++++++-------------------
 2 files changed, 32 insertions(+), 20 deletions(-)

diff --git a/app/core/engine.py b/app/core/engine.py
index 21fc800..da508bb 100644
--- a/app/core/engine.py
+++ b/app/core/engine.py
@@ -128,10 +128,18 @@ def initialize_agent(aid):
             model_name=agent.model,
             openai_api_key=config.deepseek_api_key,
             openai_api_base="https://api.deepseek.com",
+            presence_penalty=1,
+            streaming=False,
+            timeout=90,
         )
         input_token_limit = 60000
     else:
-        llm = ChatOpenAI(model_name=agent.model, openai_api_key=config.openai_api_key)
+        llm = ChatOpenAI(
+            model_name=agent.model,
+            openai_api_key=config.openai_api_key,
+            timeout=60,
+            presence_penalty=1,
+        )
 
     # ==== Store buffered conversation history in memory.
     memory = PostgresSaver(get_coon())
diff --git a/app/core/graph.py b/app/core/graph.py
index 1bc1cfb..ad31cb9 100644
--- a/app/core/graph.py
+++ b/app/core/graph.py
@@ -147,6 +147,7 @@ def _validate_chat_history(
 # Cache for tiktoken encoders
 _TIKTOKEN_CACHE = {}
 
+
 def _get_encoder(model_name: str = "gpt-4"):
     """Get cached tiktoken encoder."""
     if model_name not in _TIKTOKEN_CACHE:
@@ -156,22 +157,23 @@ def _get_encoder(model_name: str = "gpt-4"):
             _TIKTOKEN_CACHE[model_name] = tiktoken.get_encoding("cl100k_base")
     return _TIKTOKEN_CACHE[model_name]
 
+
 def _count_tokens(messages: Sequence[BaseMessage], model_name: str = "gpt-4") -> int:
     """Count the number of tokens in a list of messages."""
     encoding = _get_encoder(model_name)
-
+
     num_tokens = 0
     for message in messages:
         # Every message follows {role/name}\n{content}\n
         num_tokens += 4
-
+
         # Count tokens for basic message attributes
-        msg_dict = message.dict()
+        msg_dict = message.model_dump()
         for key in ["content", "name", "function_call", "role"]:
             value = msg_dict.get(key)
             if value:
                 num_tokens += len(encoding.encode(str(value)))
-
+
         # Count tokens for tool calls more efficiently
         if hasattr(message, "tool_calls") and message.tool_calls:
             for tool_call in message.tool_calls:
@@ -183,9 +185,10 @@ def _count_tokens(messages: Sequence[BaseMessage], model_name: str = "gpt-4") ->
             else:
                 # Handle tool_call object if it's not a dict
                 num_tokens += len(encoding.encode(str(tool_call)))
-
+
     return num_tokens
 
+
 def _limit_tokens(
     messages: list[BaseMessage], max_tokens: int = 120000
 ) -> list[BaseMessage]:
@@ -193,7 +196,7 @@
     Also merges adjacent HumanMessages as some models don't allow consecutive user messages."""
     original_count = len(messages)
     logger.debug(f"Starting token limiting. Original message count: {original_count}")
-
+
     # First merge adjacent HumanMessages
     i = 0
     merge_count = 0
@@ -219,16 +222,16 @@ def _limit_tokens(
             merge_count += 1
         else:
             i += 1
-
+
     if merge_count > 0:
         logger.debug(f"Merged {merge_count} adjacent human messages")
 
     token_count = _count_tokens(messages)
     logger.debug(f"Current token count: {token_count}, max allowed: {max_tokens}")
-
+
     if token_count <= max_tokens:
         return messages
-
+
     # Modify messages in-place instead of creating a copy
     i = 1
     removed_count = 0
@@ -238,16 +241,17 @@ def _limit_tokens(
             removed_count += 1
         else:
             i += 1
-
+
     final_token_count = _count_tokens(messages)
-    logger.debug(
+    logger.info(
         f"Token limiting complete. Removed {removed_count} messages. "
" f"Final message count: {len(messages)}, " f"Final token count: {final_token_count}" ) - + return messages + def create_agent( model: LanguageModelLike, tools: Union[ToolExecutor, Sequence[BaseTool], ToolNode], @@ -387,29 +391,29 @@ def default_memory_manager(state: AgentState) -> AgentState: # Define the function that calls the model def call_model(state: AgentState, config: RunnableConfig) -> AgentState: _validate_chat_history(state["messages"]) - + # Log message state before token limiting msg_count = len(state["messages"]) token_count = _count_tokens(state["messages"]) logger.debug( f"Before token limiting - Messages: {msg_count}, Tokens: {token_count}" ) - + # Limit tokens before calling model state["messages"] = _limit_tokens(state["messages"], input_token_limit) - + # Log state after token limiting msg_count = len(state["messages"]) token_count = _count_tokens(state["messages"]) logger.debug( f"After token limiting - Messages: {msg_count}, Tokens: {token_count}" ) - + try: logger.debug("Starting model invocation...") response = model_runnable.invoke(state, config) logger.debug(f"Model invocation completed. Response type: {type(response)}") - + # Log response details if isinstance(response, AIMessage): has_tool_calls = bool(response.tool_calls) @@ -418,11 +422,11 @@ def call_model(state: AgentState, config: RunnableConfig) -> AgentState: logger.debug(f"Number of tool calls: {len(response.tool_calls)}") else: logger.debug(f"Response is not AIMessage: {type(response)}") - + except Exception as e: logger.error(f"Error in call model: {e}", exc_info=True) raise e - + has_tool_calls = isinstance(response, AIMessage) and response.tool_calls all_tools_return_direct = ( all(call["name"] in should_return_direct for call in response.tool_calls)