Skip to content

Commit c8c076d

Browse files
committed
chore: refactor code
1 parent acc65f6 commit c8c076d

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

libs/superagent/app/agents/base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ def __init__(
2121
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
2222
llm_params: Optional[LLMParams] = {},
2323
agent_config: Agent = None,
24-
memory_config: MemoryDb = None,
24+
memory_config: Optional[MemoryDb] = None,
2525
):
2626
self.agent_id = agent_id
2727
self.session_id = session_id

libs/superagent/app/agents/langchain.py

Lines changed: 23 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import json
33
import logging
44
import re
5-
from typing import Any, List
5+
from typing import Any, List, Optional
66

77
from decouple import config
88
from langchain.agents import AgentType, initialize_agent
@@ -26,6 +26,7 @@
2626
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool
2727
from app.utils.helpers import get_first_non_null
2828
from app.utils.llm import LLM_MAPPING
29+
from prisma.enums import LLMProvider, MemoryDbProvider
2930
from prisma.models import LLM, Agent, AgentDatasource, AgentTool, MemoryDb
3031

3132
logger = logging.getLogger(__name__)
@@ -152,7 +153,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any:
152153
**(self.llm_params.dict() if self.llm_params else {}),
153154
}
154155

155-
if llm.provider == "OPENAI":
156+
if llm.provider == LLMProvider.OPENAI:
156157
return ChatOpenAI(
157158
model=LLM_MAPPING[model],
158159
openai_api_key=llm.apiKey,
@@ -161,7 +162,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any:
161162
**(llm.options if llm.options else {}),
162163
**(llm_params),
163164
)
164-
elif llm.provider == "AZURE_OPENAI":
165+
elif llm.provider == LLMProvider.AZURE_OPENAI:
165166
return AzureChatOpenAI(
166167
api_key=llm.apiKey,
167168
streaming=self.enable_streaming,
@@ -197,15 +198,19 @@ async def _get_prompt(self, agent: Agent) -> str:
197198
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}"
198199
return SystemMessage(content=content)
199200

200-
async def _get_memory(self, memory_db: MemoryDb) -> List:
201+
async def _get_memory(self, memory_db: Optional[MemoryDb]) -> List:
201202
logger.debug(f"Use memory config: {memory_db}")
202203
if memory_db is None:
203-
memory_provider = config("MEMORY")
204+
memory_provider = config("MEMORY", "motorhead")
204205
options = {}
205206
else:
206207
memory_provider = memory_db.provider
207208
options = memory_db.options
208-
if memory_provider == "REDIS" or memory_provider == "redis":
209+
210+
memory_provider = memory_provider.upper()
211+
logger.info(f"Using memory provider: {memory_provider}")
212+
213+
if memory_provider == MemoryDbProvider.REDIS:
209214
memory = ConversationBufferWindowMemory(
210215
chat_memory=RedisChatMessageHistory(
211216
session_id=(
@@ -227,18 +232,25 @@ async def _get_memory(self, memory_db: MemoryDb) -> List:
227232
config("REDIS_MEMORY_WINDOW", 10),
228233
),
229234
)
230-
elif memory_provider == "MOTORHEAD" or memory_provider == "motorhead":
235+
elif memory_provider == MemoryDbProvider.MOTORHEAD:
236+
url = get_first_non_null(
237+
options.get("MEMORY_API_URL"),
238+
config("MEMORY_API_URL"),
239+
)
240+
241+
if not url:
242+
raise ValueError(
243+
"Memory API URL is required for Motorhead memory provider"
244+
)
245+
231246
memory = MotorheadMemory(
232247
session_id=(
233248
f"{self.agent_id}-{self.session_id}"
234249
if self.session_id
235250
else f"{self.agent_id}"
236251
),
237252
memory_key="chat_history",
238-
url=get_first_non_null(
239-
options.get("MEMORY_API_URL"),
240-
config("MEMORY_API_URL"),
241-
),
253+
url=url,
242254
return_messages=True,
243255
output_key="output",
244256
)

0 commit comments

Comments (0)