-
Notifications
You must be signed in to change notification settings - Fork 897
Support set Memory from UI #852
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
import datetime | ||
import json | ||
import logging | ||
import re | ||
from typing import Any, List | ||
from typing import Any, List, Optional | ||
|
||
from decouple import config | ||
from langchain.agents import AgentType, initialize_agent | ||
|
@@ -23,8 +24,12 @@ | |
from app.models.tools import DatasourceInput | ||
from app.tools import TOOL_TYPE_MAPPING, create_pydantic_model_from_object, create_tool | ||
from app.tools.datasource import DatasourceTool, StructuredDatasourceTool | ||
from app.utils.helpers import get_first_non_null | ||
from app.utils.llm import LLM_MAPPING | ||
from prisma.models import LLM, Agent, AgentDatasource, AgentTool | ||
from prisma.enums import LLMProvider, MemoryDbProvider | ||
from prisma.models import LLM, Agent, AgentDatasource, AgentTool, MemoryDb | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
DEFAULT_PROMPT = ( | ||
"You are a helpful AI Assistant, answer the users questions to " | ||
|
@@ -148,7 +153,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any: | |
**(self.llm_params.dict() if self.llm_params else {}), | ||
} | ||
|
||
if llm.provider == "OPENAI": | ||
if llm.provider == LLMProvider.OPENAI: | ||
return ChatOpenAI( | ||
model=LLM_MAPPING[model], | ||
openai_api_key=llm.apiKey, | ||
|
@@ -157,7 +162,7 @@ async def _get_llm(self, llm: LLM, model: str) -> Any: | |
**(llm.options if llm.options else {}), | ||
**(llm_params), | ||
) | ||
elif llm.provider == "AZURE_OPENAI": | ||
elif llm.provider == LLMProvider.AZURE_OPENAI: | ||
return AzureChatOpenAI( | ||
api_key=llm.apiKey, | ||
streaming=self.enable_streaming, | ||
|
@@ -193,33 +198,59 @@ async def _get_prompt(self, agent: Agent) -> str: | |
content = f"{content}" f"\n\n{datetime.datetime.now().strftime('%Y-%m-%d')}" | ||
return SystemMessage(content=content) | ||
|
||
async def _get_memory(self) -> List: | ||
memory_type = config("MEMORY", "motorhead") | ||
if memory_type == "redis": | ||
async def _get_memory(self, memory_db: Optional[MemoryDb]) -> List: | ||
logger.debug(f"Use memory config: {memory_db}") | ||
if memory_db is None: | ||
memory_provider = config("MEMORY", "motorhead") | ||
options = {} | ||
else: | ||
memory_provider = memory_db.provider | ||
options = memory_db.options | ||
|
||
memory_provider = memory_provider.upper() | ||
logger.info(f"Using memory provider: {memory_provider}") | ||
|
||
if memory_provider == MemoryDbProvider.REDIS: | ||
memory = ConversationBufferWindowMemory( | ||
chat_memory=RedisChatMessageHistory( | ||
session_id=( | ||
f"{self.agent_id}-{self.session_id}" | ||
if self.session_id | ||
else f"{self.agent_id}" | ||
), | ||
url=config("REDIS_MEMORY_URL", "redis://localhost:6379/0"), | ||
url=get_first_non_null( | ||
options.get("REDIS_MEMORY_URL"), | ||
config("REDIS_MEMORY_URL", "redis://localhost:6379/0"), | ||
), | ||
key_prefix="superagent:", | ||
), | ||
memory_key="chat_history", | ||
return_messages=True, | ||
output_key="output", | ||
k=config("REDIS_MEMORY_WINDOW", 10), | ||
k=get_first_non_null( | ||
options.get("REDIS_MEMORY_WINDOW"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If one wants to define different REDIS_MEMORY_WINDOW per each agent. We can't achieve it by defining global memory and use it this way. Is it something we should care about? @homanp There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @elisalimli @homanp Perhaps |
||
config("REDIS_MEMORY_WINDOW", 10), | ||
), | ||
) | ||
else: | ||
elif memory_provider == MemoryDbProvider.MOTORHEAD: | ||
url = get_first_non_null( | ||
options.get("MEMORY_API_URL"), | ||
config("MEMORY_API_URL"), | ||
) | ||
|
||
if not url: | ||
raise ValueError( | ||
"Memory API URL is required for Motorhead memory provider" | ||
) | ||
|
||
memory = MotorheadMemory( | ||
session_id=( | ||
f"{self.agent_id}-{self.session_id}" | ||
if self.session_id | ||
else f"{self.agent_id}" | ||
), | ||
memory_key="chat_history", | ||
url=config("MEMORY_API_URL"), | ||
url=url, | ||
return_messages=True, | ||
output_key="output", | ||
) | ||
|
@@ -235,7 +266,7 @@ async def get_agent(self): | |
agent_tools=self.agent_config.tools, | ||
) | ||
prompt = await self._get_prompt(agent=self.agent_config) | ||
memory = await self._get_memory() | ||
memory = await self._get_memory(memory_db=self.memory_config) | ||
|
||
if len(tools) > 0: | ||
agent = initialize_agent( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import json | ||
|
||
import segment.analytics as analytics | ||
from decouple import config | ||
from fastapi import APIRouter, Depends | ||
|
||
from app.models.request import MemoryDb as MemoryDbRequest | ||
from app.models.response import MemoryDb as MemoryDbResponse | ||
from app.models.response import MemoryDbList as MemoryDbListResponse | ||
from app.utils.api import get_current_api_user, handle_exception | ||
from app.utils.prisma import prisma | ||
from prisma import Json | ||
|
||
# Segment analytics write key; when unset (None), event tracking is skipped
# by the endpoints below.
SEGMENT_WRITE_KEY = config("SEGMENT_WRITE_KEY", None)

# Router for all Memory Database endpoints; analytics client is configured
# at import time with the (possibly None) write key.
router = APIRouter()
analytics.write_key = SEGMENT_WRITE_KEY
||
|
||
@router.post(
    "/memory-db",
    name="create",
    description="Create a new Memory Database",
    response_model=MemoryDbResponse,
)
async def create(body: MemoryDbRequest, api_user=Depends(get_current_api_user)):
    """Endpoint for creating a Memory Database"""
    try:
        if SEGMENT_WRITE_KEY:
            analytics.track(api_user.id, "Created Memory Database")

        data = await prisma.memorydb.create(
            {
                **body.dict(),
                "apiUserId": api_user.id,
                # NOTE(review): the PATCH endpoint stores options via prisma's
                # Json(...), while this uses json.dumps — confirm both are
                # intended to persist the same representation.
                "options": json.dumps(body.options),
            }
        )
        # Response model expects options serialized as a JSON string,
        # matching the sibling endpoints.
        data.options = json.dumps(data.options)
        return {"success": True, "data": data}
    except Exception as e:
        # Consistency fix: every other endpoint in this router funnels
        # failures through the shared handler; create previously let raw
        # exceptions propagate.
        handle_exception(e)
|
||
|
||
@router.get(
    "/memory-dbs",
    name="list",
    description="List all Memory Databases",
    response_model=MemoryDbListResponse,
)
async def list(api_user=Depends(get_current_api_user)):
    """Endpoint for listing all Memory Databases"""
    try:
        records = await prisma.memorydb.find_many(
            where={"apiUserId": api_user.id}, order={"createdAt": "desc"}
        )
        # The response model expects options serialized as a JSON string.
        for record in records:
            record.options = json.dumps(record.options)
        return {"success": True, "data": records}
    except Exception as e:
        handle_exception(e)
|
||
|
||
@router.get(
    "/memory-dbs/{memory_db_id}",
    name="get",
    description="Get a single Memory Database",
    response_model=MemoryDbResponse,
)
async def get(memory_db_id: str, api_user=Depends(get_current_api_user)):
    """Endpoint for getting a single Memory Database"""
    try:
        data = await prisma.memorydb.find_first(
            where={"id": memory_db_id, "apiUserId": api_user.id}
        )
        # find_first returns None when no record matches; the original
        # unconditional `data.options` access raised AttributeError here,
        # masking "not found" as a generic serialization failure. Guard so
        # only an existing record is serialized.
        if data is not None:
            data.options = json.dumps(data.options)
        return {"success": True, "data": data}
    except Exception as e:
        handle_exception(e)
|
||
|
||
@router.patch(
    "/memory-dbs/{memory_db_id}",
    name="update",
    description="Patch a Memory Database",
    response_model=MemoryDbResponse,
)
async def update(
    memory_db_id: str, body: MemoryDbRequest, api_user=Depends(get_current_api_user)
):
    """Endpoint for patching a Memory Database"""
    try:
        if SEGMENT_WRITE_KEY:
            analytics.track(api_user.id, "Updated Memory Database")
        # Only fields the caller actually set are patched; ownership and
        # options are always (re)written.
        payload = {
            **body.dict(exclude_unset=True),
            "apiUserId": api_user.id,
            "options": Json(body.options),
        }
        updated = await prisma.memorydb.update(
            where={"id": memory_db_id},
            data=payload,
        )
        # Response model expects options serialized as a JSON string.
        updated.options = json.dumps(updated.options)
        return {"success": True, "data": updated}
    except Exception as e:
        handle_exception(e)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@bdqfork Why did you move this line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Separating code and environment dependencies allows for optimal utilization of the cache built by Docker. During the debugging phase, there are more changes in the code compared to changes in the environment dependencies.