2
2
import json
3
3
import logging
4
4
import re
5
- from typing import Any , List
5
+ from typing import Any , List , Optional
6
6
7
7
from decouple import config
8
8
from langchain .agents import AgentType , initialize_agent
26
26
from app .tools .datasource import DatasourceTool , StructuredDatasourceTool
27
27
from app .utils .helpers import get_first_non_null
28
28
from app .utils .llm import LLM_MAPPING
29
+ from prisma .enums import LLMProvider , MemoryDbProvider
29
30
from prisma .models import LLM , Agent , AgentDatasource , AgentTool , MemoryDb
30
31
31
32
logger = logging .getLogger (__name__ )
@@ -152,7 +153,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any:
152
153
** (self .llm_params .dict () if self .llm_params else {}),
153
154
}
154
155
155
- if llm .provider == " OPENAI" :
156
+ if llm .provider == LLMProvider . OPENAI :
156
157
return ChatOpenAI (
157
158
model = LLM_MAPPING [model ],
158
159
openai_api_key = llm .apiKey ,
@@ -161,7 +162,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any:
161
162
** (llm .options if llm .options else {}),
162
163
** (llm_params ),
163
164
)
164
- elif llm .provider == " AZURE_OPENAI" :
165
+ elif llm .provider == LLMProvider . AZURE_OPENAI :
165
166
return AzureChatOpenAI (
166
167
api_key = llm .apiKey ,
167
168
streaming = self .enable_streaming ,
@@ -197,15 +198,19 @@ async def _get_prompt(self, agent: Agent) -> str:
197
198
content = f"{ content } " f"\n \n { datetime .datetime .now ().strftime ('%Y-%m-%d' )} "
198
199
return SystemMessage (content = content )
199
200
200
- async def _get_memory (self , memory_db : MemoryDb ) -> List :
201
+ async def _get_memory (self , memory_db : Optional [ MemoryDb ] ) -> List :
201
202
logger .debug (f"Use memory config: { memory_db } " )
202
203
if memory_db is None :
203
- memory_provider = config ("MEMORY" )
204
+ memory_provider = config ("MEMORY" , "motorhead" )
204
205
options = {}
205
206
else :
206
207
memory_provider = memory_db .provider
207
208
options = memory_db .options
208
- if memory_provider == "REDIS" or memory_provider == "redis" :
209
+
210
+ memory_provider = memory_provider .upper ()
211
+ logger .info (f"Using memory provider: { memory_provider } " )
212
+
213
+ if memory_provider == MemoryDbProvider .REDIS :
209
214
memory = ConversationBufferWindowMemory (
210
215
chat_memory = RedisChatMessageHistory (
211
216
session_id = (
@@ -227,18 +232,25 @@ async def _get_memory(self, memory_db: MemoryDb) -> List:
227
232
config ("REDIS_MEMORY_WINDOW" , 10 ),
228
233
),
229
234
)
230
- elif memory_provider == "MOTORHEAD" or memory_provider == "motorhead" :
235
+ elif memory_provider == MemoryDbProvider .MOTORHEAD :
236
+ url = get_first_non_null (
237
+ options .get ("MEMORY_API_URL" ),
238
+ config ("MEMORY_API_URL" ),
239
+ )
240
+
241
+ if not url :
242
+ raise ValueError (
243
+ "Memory API URL is required for Motorhead memory provider"
244
+ )
245
+
231
246
memory = MotorheadMemory (
232
247
session_id = (
233
248
f"{ self .agent_id } -{ self .session_id } "
234
249
if self .session_id
235
250
else f"{ self .agent_id } "
236
251
),
237
252
memory_key = "chat_history" ,
238
- url = get_first_non_null (
239
- options .get ("MEMORY_API_URL" ),
240
- config ("MEMORY_API_URL" ),
241
- ),
253
+ url = url ,
242
254
return_messages = True ,
243
255
output_key = "output" ,
244
256
)
0 commit comments