Skip to content

Commit 88ff32e

Browse files
committed
support set memory from ui
1 parent 5132154 commit 88ff32e

File tree

23 files changed

+575
-20
lines changed

23 files changed

+575
-20
lines changed

libs/superagent/Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ ENV PORT="8080"
3131

3232
COPY --from=builder /app/.venv /app/.venv
3333

34-
COPY . ./
35-
3634
# Improve grpc error messages
3735
RUN pip install grpcio-status
3836

37+
COPY . ./
38+
3939
# Enable prisma migrations
4040
RUN prisma generate
4141

libs/superagent/app/agents/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from app.models.request import LLMParams
44
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
55
from prisma.enums import AgentType
6-
from prisma.models import Agent
6+
from prisma.models import Agent, MemoryDb
77

88
DEFAULT_PROMPT = (
99
"You are a helpful AI Assistant, answer the users questions to "
@@ -21,6 +21,7 @@ def __init__(
2121
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
2222
llm_params: Optional[LLMParams] = {},
2323
agent_config: Agent = None,
24+
memory_config: MemoryDb = None,
2425
):
2526
self.agent_id = agent_id
2627
self.session_id = session_id
@@ -29,6 +30,7 @@ def __init__(
2930
self.callbacks = callbacks
3031
self.llm_params = llm_params
3132
self.agent_config = agent_config
33+
self.memory_config = memory_config
3234

3335
async def _get_tools(
3436
self,
@@ -60,6 +62,7 @@ async def get_agent(self):
6062
callbacks=self.callbacks,
6163
llm_params=self.llm_params,
6264
agent_config=self.agent_config,
65+
memory_config=self.memory_config,
6366
)
6467

6568
elif self.agent_config.type == AgentType.LLM:
@@ -72,6 +75,7 @@ async def get_agent(self):
7275
callbacks=self.callbacks,
7376
llm_params=self.llm_params,
7477
agent_config=self.agent_config,
78+
memory_config=self.memory_config,
7579
)
7680

7781
else:
@@ -85,6 +89,7 @@ async def get_agent(self):
8589
callbacks=self.callbacks,
8690
llm_params=self.llm_params,
8791
agent_config=self.agent_config,
92+
memory_config=self.memory_config,
8893
)
8994

9095
return await agent.get_agent()

libs/superagent/app/agents/langchain.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import json
3+
import logging
34
import re
45
from typing import Any, List
56

@@ -23,8 +24,11 @@
2324
from app.models.tools import DatasourceInput
2425
from app.tools import TOOL_TYPE_MAPPING, create_pydantic_model_from_object, create_tool
2526
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool
27+
from app.utils.helpers import get_first_non_null
2628
from app.utils.llm import LLM_MAPPING
27-
from prisma.models import LLM, Agent, AgentDatasource, AgentTool
29+
from prisma.models import LLM, Agent, AgentDatasource, AgentTool, MemoryDb
30+
31+
logger = logging.getLogger(__name__)
2832

2933
DEFAULT_PROMPT = (
3034
"You are a helpful AI Assistant, answer the users questions to "
@@ -193,33 +197,48 @@ async def _get_prompt(self, agent: Agent) -> str:
193197
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}"
194198
return SystemMessage(content=content)
195199

196-
async def _get_memory(self) -> List:
197-
memory_type = config("MEMORY", "motorhead")
198-
if memory_type == "redis":
200+
async def _get_memory(self, memory_db: MemoryDb) -> List:
201+
logger.debug(f"Use memory config: {memory_db}")
202+
if memory_db is None:
203+
memory_provider = config("MEMORY")
204+
options = {}
205+
else:
206+
memory_provider = memory_db.provider
207+
options = memory_db.options
208+
if memory_provider == "REDIS" or memory_provider == "redis":
199209
memory = ConversationBufferWindowMemory(
200210
chat_memory=RedisChatMessageHistory(
201211
session_id=(
202212
f"{self.agent_id}-{self.session_id}"
203213
if self.session_id
204214
else f"{self.agent_id}"
205215
),
206-
url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
216+
url=get_first_non_null(
217+
options.get("REDIS_MEMORY_URL"),
218+
config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
219+
),
207220
key_prefix="superagent:",
208221
),
209222
memory_key="chat_history",
210223
return_messages=True,
211224
output_key="output",
212-
k=config("REDIS_MEMORY_WINDOW", 10),
225+
k=get_first_non_null(
226+
options.get("REDIS_MEMORY_WINDOW"),
227+
config("REDIS_MEMORY_WINDOW", 10),
228+
),
213229
)
214-
else:
230+
elif memory_provider == "MOTORHEAD" or memory_provider == "motorhead":
215231
memory = MotorheadMemory(
216232
session_id=(
217233
f"{self.agent_id}-{self.session_id}"
218234
if self.session_id
219235
else f"{self.agent_id}"
220236
),
221237
memory_key="chat_history",
222-
url=config("MEMORY_API_URL"),
238+
url=get_first_non_null(
239+
options.get("MEMORY_API_URL"),
240+
config("MEMORY_API_URL"),
241+
),
223242
return_messages=True,
224243
output_key="output",
225244
)
@@ -235,7 +254,7 @@ async def get_agent(self):
235254
agent_tools=self.agent_config.tools,
236255
)
237256
prompt = await self._get_prompt(agent=self.agent_config)
238-
memory = await self._get_memory()
257+
memory = await self._get_memory(memory_db=self.memory_config)
239258

240259
if len(tools) > 0:
241260
agent = initialize_agent(

libs/superagent/app/api/agents.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,10 @@ async def invoke(
452452
if not model and metadata.get("model"):
453453
model = metadata.get("model")
454454

455+
memory_config = await prisma.memorydb.find_first(
456+
where={"provider": agent_config.memory, "apiUserId": api_user.id},
457+
)
458+
455459
def track_agent_invocation(result):
456460
intermediate_steps_to_obj = [
457461
{
@@ -571,6 +575,7 @@ async def send_message(
571575
callbacks=monitoring_callbacks,
572576
llm_params=body.llm_params,
573577
agent_config=agent_config,
578+
memory_config=memory_config,
574579
)
575580
agent = await agent_base.get_agent()
576581

libs/superagent/app/api/memory_dbs.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import json
2+
3+
import segment.analytics as analytics
4+
from decouple import config
5+
from fastapi import APIRouter, Depends
6+
7+
from app.models.request import MemoryDb as MemoryDbRequest
8+
from app.models.response import MemoryDb as MemoryDbResponse
9+
from app.models.response import MemoryDbList as MemoryDbListResponse
10+
from app.utils.api import get_current_api_user, handle_exception
11+
from app.utils.prisma import prisma
12+
from prisma import Json
13+
14+
SEGMENT_WRITE_KEY = config("SEGMENT_WRITE_KEY", None)
15+
16+
router = APIRouter()
17+
analytics.write_key = SEGMENT_WRITE_KEY
18+
19+
20+
@router.post(
21+
"/memory-db",
22+
name="create",
23+
description="Create a new Memory Database",
24+
response_model=MemoryDbResponse,
25+
)
26+
async def create(body: MemoryDbRequest, api_user=Depends(get_current_api_user)):
27+
"""Endpoint for creating a Memory Database"""
28+
if SEGMENT_WRITE_KEY:
29+
analytics.track(api_user.id, "Created Memory Database")
30+
31+
data = await prisma.memorydb.create(
32+
{
33+
**body.dict(),
34+
"apiUserId": api_user.id,
35+
"options": json.dumps(body.options),
36+
}
37+
)
38+
data.options = json.dumps(data.options)
39+
return {"success": True, "data": data}
40+
41+
42+
@router.get(
43+
"/memory-dbs",
44+
name="list",
45+
description="List all Memory Databases",
46+
response_model=MemoryDbListResponse,
47+
)
48+
async def list(api_user=Depends(get_current_api_user)):
49+
"""Endpoint for listing all Memory Databases"""
50+
try:
51+
data = await prisma.memorydb.find_many(
52+
where={"apiUserId": api_user.id}, order={"createdAt": "desc"}
53+
)
54+
# Convert options to string
55+
for item in data:
56+
item.options = json.dumps(item.options)
57+
return {"success": True, "data": data}
58+
except Exception as e:
59+
handle_exception(e)
60+
61+
62+
@router.get(
63+
"/memory-dbs/{memory_db_id}",
64+
name="get",
65+
description="Get a single Memory Database",
66+
response_model=MemoryDbResponse,
67+
)
68+
async def get(memory_db_id: str, api_user=Depends(get_current_api_user)):
69+
"""Endpoint for getting a single Memory Database"""
70+
try:
71+
data = await prisma.memorydb.find_first(
72+
where={"id": memory_db_id, "apiUserId": api_user.id}
73+
)
74+
data.options = json.dumps(data.options)
75+
return {"success": True, "data": data}
76+
except Exception as e:
77+
handle_exception(e)
78+
79+
80+
@router.patch(
81+
"/memory-dbs/{memory_db_id}",
82+
name="update",
83+
description="Patch a Memory Database",
84+
response_model=MemoryDbResponse,
85+
)
86+
async def update(
87+
memory_db_id: str, body: MemoryDbRequest, api_user=Depends(get_current_api_user)
88+
):
89+
"""Endpoint for patching a Memory Database"""
90+
try:
91+
if SEGMENT_WRITE_KEY:
92+
analytics.track(api_user.id, "Updated Memory Database")
93+
data = await prisma.memorydb.update(
94+
where={"id": memory_db_id},
95+
data={
96+
**body.dict(exclude_unset=True),
97+
"apiUserId": api_user.id,
98+
"options": Json(body.options),
99+
},
100+
)
101+
data.options = json.dumps(data.options)
102+
return {"success": True, "data": data}
103+
except Exception as e:
104+
handle_exception(e)

libs/superagent/app/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
import time
34

45
import colorlog
@@ -26,7 +27,7 @@
2627
console_handler.setFormatter(formatter)
2728

2829
logging.basicConfig(
29-
level=logging.INFO,
30+
level=os.environ.get("LOG_LEVEL", "INFO"),
3031
format="%(levelname)s: %(message)s",
3132
handlers=[console_handler],
3233
force=True,

libs/superagent/app/models/request.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from openai.types.beta.assistant_create_params import Tool as OpenAiAssistantTool
44
from pydantic import BaseModel
55

6-
from prisma.enums import AgentType, LLMProvider, VectorDbProvider
6+
from prisma.enums import AgentType, LLMProvider, MemoryDbProvider, VectorDbProvider
77

88

99
class ApiUser(BaseModel):
@@ -40,6 +40,7 @@ class AgentUpdate(BaseModel):
4040
initialMessage: Optional[str]
4141
prompt: Optional[str]
4242
llmModel: Optional[str]
43+
memory: Optional[str]
4344
description: Optional[str]
4445
avatar: Optional[str]
4546
type: Optional[str]
@@ -132,3 +133,8 @@ class WorkflowInvoke(BaseModel):
132133
class VectorDb(BaseModel):
133134
provider: VectorDbProvider
134135
options: Dict
136+
137+
138+
class MemoryDb(BaseModel):
139+
provider: MemoryDbProvider
140+
options: Dict

libs/superagent/app/models/response.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from prisma.models import (
2121
Datasource as DatasourceModel,
2222
)
23+
from prisma.models import (
24+
MemoryDb as MemoryDbModel,
25+
)
2326
from prisma.models import (
2427
Tool as ToolModel,
2528
)
@@ -141,3 +144,13 @@ class VectorDb(BaseModel):
141144
class VectorDbList(BaseModel):
142145
success: bool
143146
data: Optional[List[VectorDbModel]]
147+
148+
149+
class MemoryDb(BaseModel):
150+
success: bool
151+
data: Optional[MemoryDbModel]
152+
153+
154+
class MemoryDbList(BaseModel):
155+
success: bool
156+
data: Optional[List[MemoryDbModel]]

libs/superagent/app/routers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
api_user,
66
datasources,
77
llms,
8+
memory_dbs,
89
tools,
910
vector_dbs,
1011
workflows,
@@ -24,3 +25,4 @@
2425
workflow_configs.router, tags=["Workflow Config"], prefix=api_prefix
2526
)
2627
router.include_router(vector_dbs.router, tags=["Vector Database"], prefix=api_prefix)
28+
router.include_router(memory_dbs.router, tags=["Memory Database"], prefix=api_prefix)

0 commit comments

Comments
 (0)