Skip to content

Support set Memory from UI #852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libs/superagent/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ ENV PORT="8080"

COPY --from=builder /app/.venv /app/.venv

COPY . ./

# Improve grpc error messages
RUN pip install grpcio-status

COPY . ./
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bdqfork Why did you move this line?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separating code and environment dependencies allows for optimal utilization of the cache built by Docker. During the debugging phase, there are more changes in the code compared to changes in the environment dependencies.


# Enable prisma migrations
RUN prisma generate

Expand Down
7 changes: 6 additions & 1 deletion libs/superagent/app/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from app.models.request import LLMParams
from app.utils.callbacks import CustomAsyncIteratorCallbackHandler
from prisma.enums import AgentType
from prisma.models import Agent
from prisma.models import Agent, MemoryDb

DEFAULT_PROMPT = (
"You are a helpful AI Assistant, answer the users questions to "
Expand All @@ -21,6 +21,7 @@ def __init__(
callbacks: List[CustomAsyncIteratorCallbackHandler] = [],
llm_params: Optional[LLMParams] = {},
agent_config: Agent = None,
memory_config: Optional[MemoryDb] = None,
):
self.agent_id = agent_id
self.session_id = session_id
Expand All @@ -29,6 +30,7 @@ def __init__(
self.callbacks = callbacks
self.llm_params = llm_params
self.agent_config = agent_config
self.memory_config = memory_config

async def _get_tools(
self,
Expand Down Expand Up @@ -60,6 +62,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

elif self.agent_config.type == AgentType.LLM:
Expand All @@ -72,6 +75,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

else:
Expand All @@ -85,6 +89,7 @@ async def get_agent(self):
callbacks=self.callbacks,
llm_params=self.llm_params,
agent_config=self.agent_config,
memory_config=self.memory_config,
)

return await agent.get_agent()
Expand Down
55 changes: 43 additions & 12 deletions libs/superagent/app/agents/langchain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import datetime
import json
import logging
import re
from typing import Any, List
from typing import Any, List, Optional

from decouple import config
from langchain.agents import AgentType, initialize_agent
Expand All @@ -23,8 +24,12 @@
from app.models.tools import DatasourceInput
from app.tools import TOOL_TYPE_MAPPING, create_pydantic_model_from_object, create_tool
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool
from app.utils.helpers import get_first_non_null
from app.utils.llm import LLM_MAPPING
from prisma.models import LLM, Agent, AgentDatasource, AgentTool
from prisma.enums import LLMProvider, MemoryDbProvider
from prisma.models import LLM, Agent, AgentDatasource, AgentTool, MemoryDb

logger = logging.getLogger(__name__)

DEFAULT_PROMPT = (
"You are a helpful AI Assistant, answer the users questions to "
Expand Down Expand Up @@ -148,7 +153,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any:
**(self.llm_params.dict() if self.llm_params else {}),
}

if llm.provider == "OPENAI":
if llm.provider == LLMProvider.OPENAI:
return ChatOpenAI(
model=LLM_MAPPING[model],
openai_api_key=llm.apiKey,
Expand All @@ -157,7 +162,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any:
**(llm.options if llm.options else {}),
**(llm_params),
)
elif llm.provider == "AZURE_OPENAI":
elif llm.provider == LLMProvider.AZURE_OPENAI:
return AzureChatOpenAI(
api_key=llm.apiKey,
streaming=self.enable_streaming,
Expand Down Expand Up @@ -193,33 +198,59 @@ async def _get_prompt(self, agent: Agent) -> str:
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}"
return SystemMessage(content=content)

async def _get_memory(self) -> List:
memory_type = config("MEMORY", "motorhead")
if memory_type == "redis":
async def _get_memory(self, memory_db: Optional[MemoryDb]) -> List:
logger.debug(f"Use memory config: {memory_db}")
if memory_db is None:
memory_provider = config("MEMORY", "motorhead")
options = {}
else:
memory_provider = memory_db.provider
options = memory_db.options

memory_provider = memory_provider.upper()
logger.info(f"Using memory provider: {memory_provider}")

if memory_provider == MemoryDbProvider.REDIS:
memory = ConversationBufferWindowMemory(
chat_memory=RedisChatMessageHistory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
url=get_first_non_null(
options.get("REDIS_MEMORY_URL"),
config("REDIS_MEMORY_URL", "redis://localhost:6379/0"),
),
key_prefix="superagent:",
),
memory_key="chat_history",
return_messages=True,
output_key="output",
k=config("REDIS_MEMORY_WINDOW", 10),
k=get_first_non_null(
options.get("REDIS_MEMORY_WINDOW"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If one wants to define different REDIS_MEMORY_WINDOW per each agent. We can't achieve it by defining global memory and use it this way. Is it something we should care about? @homanp

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@elisalimli @homanp Perhaps memory_size could be utilized as an agent argument, representing the number of recent conversations to give priority to. If this value exists, it will be utilized; otherwise, the global configuration will be used.

config("REDIS_MEMORY_WINDOW", 10),
),
)
else:
elif memory_provider == MemoryDbProvider.MOTORHEAD:
url = get_first_non_null(
options.get("MEMORY_API_URL"),
config("MEMORY_API_URL"),
)

if not url:
raise ValueError(
"Memory API URL is required for Motorhead memory provider"
)

memory = MotorheadMemory(
session_id=(
f"{self.agent_id}-{self.session_id}"
if self.session_id
else f"{self.agent_id}"
),
memory_key="chat_history",
url=config("MEMORY_API_URL"),
url=url,
return_messages=True,
output_key="output",
)
Expand All @@ -235,7 +266,7 @@ async def get_agent(self):
agent_tools=self.agent_config.tools,
)
prompt = await self._get_prompt(agent=self.agent_config)
memory = await self._get_memory()
memory = await self._get_memory(memory_db=self.memory_config)

if len(tools) > 0:
agent = initialize_agent(
Expand Down
5 changes: 5 additions & 0 deletions libs/superagent/app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ async def invoke(
if not model and metadata.get("model"):
model = metadata.get("model")

memory_config = await prisma.memorydb.find_first(
where={"provider": agent_config.memory, "apiUserId": api_user.id},
)

def track_agent_invocation(result):
intermediate_steps_to_obj = [
{
Expand Down Expand Up @@ -571,6 +575,7 @@ async def send_message(
callbacks=monitoring_callbacks,
llm_params=body.llm_params,
agent_config=agent_config,
memory_config=memory_config,
)
agent = await agent_base.get_agent()

Expand Down
104 changes: 104 additions & 0 deletions libs/superagent/app/api/memory_dbs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import json

import segment.analytics as analytics
from decouple import config
from fastapi import APIRouter, Depends

from app.models.request import MemoryDb as MemoryDbRequest
from app.models.response import MemoryDb as MemoryDbResponse
from app.models.response import MemoryDbList as MemoryDbListResponse
from app.utils.api import get_current_api_user, handle_exception
from app.utils.prisma import prisma
from prisma import Json

SEGMENT_WRITE_KEY = config("SEGMENT_WRITE_KEY", None)

router = APIRouter()
analytics.write_key = SEGMENT_WRITE_KEY


@router.post(
"/memory-db",
name="create",
description="Create a new Memory Database",
response_model=MemoryDbResponse,
)
async def create(body: MemoryDbRequest, api_user=Depends(get_current_api_user)):
"""Endpoint for creating a Memory Database"""
if SEGMENT_WRITE_KEY:
analytics.track(api_user.id, "Created Memory Database")

data = await prisma.memorydb.create(
{
**body.dict(),
"apiUserId": api_user.id,
"options": json.dumps(body.options),
}
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}


@router.get(
"/memory-dbs",
name="list",
description="List all Memory Databases",
response_model=MemoryDbListResponse,
)
async def list(api_user=Depends(get_current_api_user)):
"""Endpoint for listing all Memory Databases"""
try:
data = await prisma.memorydb.find_many(
where={"apiUserId": api_user.id}, order={"createdAt": "desc"}
)
# Convert options to string
for item in data:
item.options = json.dumps(item.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)


@router.get(
"/memory-dbs/{memory_db_id}",
name="get",
description="Get a single Memory Database",
response_model=MemoryDbResponse,
)
async def get(memory_db_id: str, api_user=Depends(get_current_api_user)):
"""Endpoint for getting a single Memory Database"""
try:
data = await prisma.memorydb.find_first(
where={"id": memory_db_id, "apiUserId": api_user.id}
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)


@router.patch(
"/memory-dbs/{memory_db_id}",
name="update",
description="Patch a Memory Database",
response_model=MemoryDbResponse,
)
async def update(
memory_db_id: str, body: MemoryDbRequest, api_user=Depends(get_current_api_user)
):
"""Endpoint for patching a Memory Database"""
try:
if SEGMENT_WRITE_KEY:
analytics.track(api_user.id, "Updated Memory Database")
data = await prisma.memorydb.update(
where={"id": memory_db_id},
data={
**body.dict(exclude_unset=True),
"apiUserId": api_user.id,
"options": Json(body.options),
},
)
data.options = json.dumps(data.options)
return {"success": True, "data": data}
except Exception as e:
handle_exception(e)
3 changes: 2 additions & 1 deletion libs/superagent/app/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
import time

import colorlog
Expand Down Expand Up @@ -26,7 +27,7 @@
console_handler.setFormatter(formatter)

logging.basicConfig(
level=logging.INFO,
level=os.environ.get("LOG_LEVEL", "INFO"),
format="%(levelname)s: %(message)s",
handlers=[console_handler],
force=True,
Expand Down
8 changes: 7 additions & 1 deletion libs/superagent/app/models/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from openai.types.beta.assistant_create_params import Tool as OpenAiAssistantTool
from pydantic import BaseModel

from prisma.enums import AgentType, LLMProvider, VectorDbProvider
from prisma.enums import AgentType, LLMProvider, MemoryDbProvider, VectorDbProvider


class ApiUser(BaseModel):
Expand Down Expand Up @@ -40,6 +40,7 @@ class AgentUpdate(BaseModel):
initialMessage: Optional[str]
prompt: Optional[str]
llmModel: Optional[str]
memory: Optional[str]
description: Optional[str]
avatar: Optional[str]
type: Optional[str]
Expand Down Expand Up @@ -132,3 +133,8 @@ class WorkflowInvoke(BaseModel):
class VectorDb(BaseModel):
provider: VectorDbProvider
options: Dict


class MemoryDb(BaseModel):
provider: MemoryDbProvider
options: Dict
13 changes: 13 additions & 0 deletions libs/superagent/app/models/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from prisma.models import (
Datasource as DatasourceModel,
)
from prisma.models import (
MemoryDb as MemoryDbModel,
)
from prisma.models import (
Tool as ToolModel,
)
Expand Down Expand Up @@ -141,3 +144,13 @@ class VectorDb(BaseModel):
class VectorDbList(BaseModel):
success: bool
data: Optional[List[VectorDbModel]]


class MemoryDb(BaseModel):
success: bool
data: Optional[MemoryDbModel]


class MemoryDbList(BaseModel):
success: bool
data: Optional[List[MemoryDbModel]]
2 changes: 2 additions & 0 deletions libs/superagent/app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
api_user,
datasources,
llms,
memory_dbs,
tools,
vector_dbs,
workflows,
Expand All @@ -24,3 +25,4 @@
workflow_configs.router, tags=["Workflow Config"], prefix=api_prefix
)
router.include_router(vector_dbs.router, tags=["Vector Database"], prefix=api_prefix)
router.include_router(memory_dbs.router, tags=["Memory Database"], prefix=api_prefix)
Loading