Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ft/memorydb for ask route #20

Merged
merged 3 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion client/.env.example
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
BRIAN_API_KEY = ""
DATABASE_URL=""
RPC_URL =''
BOT_USERNAME=
MY_TOKEN = ''
244 changes: 175 additions & 69 deletions client/app/api/ask/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,80 @@ import { NextResponse } from "next/server";
import { ASK_OPENAI_AGENT_PROMPT } from "@/prompts/prompts";
import axios from "axios";
import { ChatOpenAI } from "@langchain/openai";
import {
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
} from "@langchain/core/prompts";
import {
START,
END,
MessagesAnnotation,
MemorySaver,
StateGraph,
} from "@langchain/langgraph";
import { RemoveMessage } from "@langchain/core/messages";

import { ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate } from "@langchain/core/prompts";
import { START, END, MessagesAnnotation, MemorySaver, StateGraph } from "@langchain/langgraph";
// import { RemoveMessage } from "@langchain/core/messages";
import prisma from '@/lib/db';

const BRIAN_API_KEY = process.env.BRIAN_API_KEY || "";
const OPENAI_API_KEY = process.env.OPENAI_API_KEY || "";
const BRIAN_API_URL = "https://api.brianknows.org/api/v0/agent/knowledge";
const BRIAN_DEFAULT_RESPONSE: string =
"🤖 Sorry, I don’t know how to answer. The AskBrian feature allows you to ask for information on a custom-built knowledge base of resources. Contact the Brian team if you want to add new resources!";

async function getOrCreateUser(address: string) {
try {
let user = await prisma.user.findUnique({
where: { id: address },
});

if (!user) {
user = await prisma.user.create({
data: {
id: address,
email: null,
name: null,
},
});
}

return user;
} catch (error) {
console.error('Error in getOrCreateUser:', error);
throw error;
}
}

async function storeMessage({
content,
chatId,
userId,
}: {
content: any[];
chatId: string;
userId: string;
}) {
try {
const message = await prisma.message.create({
data: {
content,
chatId,
userId,
},
});
return message;
} catch (error) {
console.error('Error storing message:', error);
throw error;
}
}

async function createOrGetChat(userId: string) {
try {
await getOrCreateUser(userId);

const chat = await prisma.chat.create({
data: {
userId,
},
});
return chat;
} catch (error) {
console.error('Error creating chat:', error);
throw error;
}
}

const systemPrompt =
ASK_OPENAI_AGENT_PROMPT +
`\nThe provided chat history includes a summary of the earlier conversation.`;
Expand All @@ -49,6 +102,41 @@ const agent = new ChatOpenAI({
});
const prompt = askAgentPromptTemplate;
// const chain = prompt.pipe(agent);
async function getChatHistory(chatId: string | { configurable?: { additional_args?: { chatId?: string } } }) {
try {
const actualChatId = typeof chatId === 'object' && chatId.configurable?.additional_args?.chatId
? chatId.configurable.additional_args.chatId
: chatId;

if (!actualChatId || typeof actualChatId !== 'string') {
console.warn('Invalid chat ID provided:', chatId);
return [];
}

const messages = await prisma.message.findMany({
where: {
chatId: actualChatId
},
orderBy: {
id: 'asc'
}
});

const formattedHistory = messages.flatMap(msg => {
const content = msg.content as any[];
return content.map(c => ({
role: c.role,
content: c.content
}));
});

return formattedHistory;
} catch (error) {
console.error('Error fetching chat history:', error);
return [];
}
}

const initialCallModel = async (state: typeof MessagesAnnotation.State) => {
const messages = [
await systemMessage.format({ brianai_answer: BRIAN_DEFAULT_RESPONSE }),
Expand All @@ -57,32 +145,35 @@ const initialCallModel = async (state: typeof MessagesAnnotation.State) => {
const response = await agent.invoke(messages);
return { messages: response };
};
const callModel = async (state: typeof MessagesAnnotation.State) => {
const messageHistory = state.messages.slice(0, -1);
if (messageHistory.length >= 3) {
const lastHumanMessage = state.messages[state.messages.length - 1];

const callModel = async (state: typeof MessagesAnnotation.State, chatId?: any) => {
if (!chatId) {
return await initialCallModel(state);
}
const actualChatId = chatId?.configurable?.additional_args?.chatId || chatId;
const chatHistory = await getChatHistory(actualChatId);
const currentMessage = state.messages[state.messages.length - 1];

if (chatHistory.length > 0) {
const summaryPrompt = `
Distill the above chat messages into a single summary message.
Distill the following chat history into a single summary message.
Include as many specific details as you can.
IMPORTANT NOTE: Include all information related to user's nature about trading and what kind of trader he/she is.
`;
// const summaryMessage = HumanMessagePromptTemplate.fromTemplate([summaryPrompt]);

const summary = await agent.invoke([
...messageHistory,
...chatHistory,
{ role: "user", content: summaryPrompt },
]);
const deleteMessages = state.messages.map((m) =>
m.id ? new RemoveMessage({ id: m.id }) : null
);
const humanMessage = { role: "user", content: lastHumanMessage.content };

const response = await agent.invoke([
await systemMessage.format({ brianai_answer: BRIAN_DEFAULT_RESPONSE }),
summary,
humanMessage,
currentMessage,
]);
//console.log(response);

return {
messages: [summary, humanMessage, response, ...deleteMessages],
messages: [summary, currentMessage, response],
};
} else {
return await initialCallModel(state);
Expand All @@ -96,11 +187,13 @@ const workflow = new StateGraph(MessagesAnnotation)
const app = workflow.compile({ checkpointer: new MemorySaver() });

async function queryOpenAI({
userQuery,
brianaiResponse,
userQuery,
brianaiResponse,
chatId
}: {
userQuery: string;
brianaiResponse: string;
userQuery: string,
brianaiResponse: string,
chatId?: string
}): Promise<string> {
try {
const response = await app.invoke(
Expand All @@ -113,18 +206,21 @@ async function queryOpenAI({
],
},
{
configurable: { thread_id: "1" },
}
configurable: {
thread_id: chatId || "1",
additional_args: { chatId }
},
},
);
console.log(response);
return response.messages[response.messages.length - 1].content as string;
return response.messages[response.messages.length-1].content as string;
} catch (error) {
console.error("OpenAI Error:", error);
return "Sorry, I am unable to process your request at the moment.";
}
}

async function queryBrianAI(prompt: string): Promise<string> {

async function queryBrianAI(prompt: string, chatId?: string): Promise<string> {
try {
const response = await axios.post(
BRIAN_API_URL,
Expand All @@ -141,8 +237,9 @@ async function queryBrianAI(prompt: string): Promise<string> {
);
const brianaiAnswer = response.data.result.answer;
const openaiAnswer = await queryOpenAI({
brianaiResponse: brianaiAnswer,
brianaiResponse: brianaiAnswer,
userQuery: prompt,
chatId
});
return openaiAnswer;
} catch (error) {
Expand All @@ -153,54 +250,63 @@ async function queryBrianAI(prompt: string): Promise<string> {

export async function POST(request: Request) {
try {
const { prompt, address, messages } = await request.json();
const { prompt, address, messages, chatId } = await request.json();
const userId = address || "0x0";
await getOrCreateUser(userId);

let currentChatId = chatId;
if (!currentChatId) {
const newChat = await createOrGetChat(userId);
currentChatId = newChat.id;
}

// Filter out duplicate messages and only keep user messages
const uniqueMessages = messages
.filter((msg: any) => msg.sender === "user")
.reduce((acc: any[], curr: any) => {
// Only add if message content isn't already present
if (!acc.some((msg) => msg.content === curr.content)) {
if (!acc.some(msg => msg.content === curr.content)) {
acc.push({
sender: "user",
content: curr.content,
role: "user",
content: curr.content
});
}
return acc;
}, []);

const payload = {
prompt,
address: address || "0x0",
chainId: "4012",
messages: uniqueMessages,
};

console.log("Request payload:", JSON.stringify(payload, null, 2));

const response = await queryBrianAI(payload.prompt);

console.log("API Response:", response);
await storeMessage({
content: uniqueMessages,
chatId: currentChatId,
userId,
});

// Extract the answer from the result array
const response = await queryBrianAI(prompt, currentChatId);
if (response) {
return NextResponse.json({ answer: response });
await storeMessage({
content: [{
role: "assistant",
content: response
}],
chatId: currentChatId,
userId,
});

return NextResponse.json({
answer: response,
chatId: currentChatId
});
} else {
throw new Error("Unexpected API response format");
}
} catch (error: any) {
console.error("Detailed error:", {
message: error.message,
response: error.response?.data,
status: error.response?.status,
});

console.error('Error:', error);
if (error.code === 'P2003') {
return NextResponse.json(
{ error: 'User authentication required', details: 'Please ensure you are logged in.' },
{ status: 401 }
);
}
return NextResponse.json(
{
error: "Unable to get response from Brian's API",
details: error.response?.data || error.message,
},
{ status: error.response?.status || 500 }
{ error: 'Unable to process request', details: error.message },
{ status: 500 }
);
}
}
Loading