From f4b09616b3bd283a4a4128210562ddce1ca4de4b Mon Sep 17 00:00:00 2001 From: Marcus Schiesser Date: Thu, 14 Dec 2023 16:03:07 +0700 Subject: [PATCH] feat: added mm-rag example and started the response synthesis for it --- examples/multimodal/rag.ts | 43 ++++++++++++++ examples/multimodal/retrieve.ts | 8 +-- packages/core/src/Node.ts | 20 +++++++ packages/core/src/QueryEngine.ts | 17 +++--- packages/core/src/index.ts | 2 +- packages/core/src/indices/BaseIndex.ts | 4 +- .../src/indices/keyword/KeywordTableIndex.ts | 4 +- .../core/src/indices/summary/SummaryIndex.ts | 13 +++-- .../indices/vectorStore/VectorStoreIndex.ts | 28 ++------- packages/core/src/llm/LLM.ts | 3 +- .../MultiModalResponseSynthesizer.ts | 50 ++++++++++++++++ .../src/synthesizers/ResponseSynthesizer.ts | 49 ++++++++++++++++ .../builders.ts} | 57 ++----------------- packages/core/src/synthesizers/index.ts | 4 ++ packages/core/src/synthesizers/types.ts | 15 +++++ .../core/src/tests/CallbackManager.test.ts | 2 +- 16 files changed, 219 insertions(+), 100 deletions(-) create mode 100644 examples/multimodal/rag.ts create mode 100644 packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts create mode 100644 packages/core/src/synthesizers/ResponseSynthesizer.ts rename packages/core/src/{ResponseSynthesizer.ts => synthesizers/builders.ts} (82%) create mode 100644 packages/core/src/synthesizers/index.ts create mode 100644 packages/core/src/synthesizers/types.ts diff --git a/examples/multimodal/rag.ts b/examples/multimodal/rag.ts new file mode 100644 index 0000000000..e726baf82b --- /dev/null +++ b/examples/multimodal/rag.ts @@ -0,0 +1,43 @@ +import { + MultiModalResponseSynthesizer, + OpenAI, + ServiceContext, + VectorStoreIndex, + serviceContextFromDefaults, + storageContextFromDefaults, +} from "llamaindex"; + +export async function createIndex(serviceContext: ServiceContext) { + // set up vector store index with two vector stores, one for text, the other for images + const storageContext = await storageContextFromDefaults({ + persistDir: "storage", + storeImages: true, + }); + return await VectorStoreIndex.init({ + nodes: [], + storageContext, + serviceContext, + }); +} + +async function main() { + const llm = new OpenAI({ model: "gpt-4-vision-preview", maxTokens: 512 }); + const serviceContext = serviceContextFromDefaults({ + llm, + chunkSize: 512, + chunkOverlap: 20, + }); + const index = await createIndex(serviceContext); + + const queryEngine = index.asQueryEngine({ + responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }), + // TODO: set imageSimilarityTopK: 1, + retriever: index.asRetriever({ similarityTopK: 2 }), + }); + const result = await queryEngine.query( + "what are Vincent van Gogh's famous paintings", + ); + console.log(result.response); +} + +main().catch(console.error); diff --git a/examples/multimodal/retrieve.ts b/examples/multimodal/retrieve.ts index 9366935a49..9c768cb5b5 100644 --- a/examples/multimodal/retrieve.ts +++ b/examples/multimodal/retrieve.ts @@ -7,7 +7,7 @@ import { } from "llamaindex"; import * as path from "path"; -export async function createRetriever() { +export async function createIndex() { // set up vector store index with two vector stores, one for text, the other for images const serviceContext = serviceContextFromDefaults({ chunkSize: 512, @@ -17,17 +17,17 @@ export async function createRetriever() { persistDir: "storage", storeImages: true, }); - const index = await VectorStoreIndex.init({ + return await VectorStoreIndex.init({ nodes: [], storageContext, serviceContext, }); - return index.asRetriever({ similarityTopK: 3 }); } async function main() { // retrieve documents using the index - const retriever = await createRetriever(); + const index = await createIndex(); + const retriever = index.asRetriever({ similarityTopK: 3 }); const results = await retriever.retrieve( "what are Vincent van Gogh's famous paintings", ); diff --git a/packages/core/src/Node.ts b/packages/core/src/Node.ts index 67ed91a1c1..a1d6c5e988 100644 --- a/packages/core/src/Node.ts +++ b/packages/core/src/Node.ts @@ -327,3 +327,23 @@ export interface NodeWithScore { node: BaseNode; score?: number; } + +export function splitNodesByType(nodes: BaseNode[]): { + imageNodes: ImageNode[]; + textNodes: TextNode[]; +} { + let imageNodes: ImageNode[] = []; + let textNodes: TextNode[] = []; + + for (let node of nodes) { + if (node instanceof ImageNode) { + imageNodes.push(node); + } else if (node instanceof TextNode) { + textNodes.push(node); + } + } + return { + imageNodes, + textNodes, + }; +} diff --git a/packages/core/src/QueryEngine.ts b/packages/core/src/QueryEngine.ts index abfb52d81c..8e032fbf84 100644 --- a/packages/core/src/QueryEngine.ts +++ b/packages/core/src/QueryEngine.ts @@ -1,6 +1,4 @@ import { v4 as uuidv4 } from "uuid"; -import { Event } from "./callbacks/CallbackManager"; -import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor"; import { NodeWithScore, TextNode } from "./Node"; import { BaseQuestionGenerator, @@ -8,10 +6,13 @@ import { SubQuestion, } from "./QuestionGenerator"; import { Response } from "./Response"; -import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer"; import { BaseRetriever } from "./Retriever"; import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; import { QueryEngineTool, ToolMetadata } from "./Tool"; +import { Event } from "./callbacks/CallbackManager"; +import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor"; +import { CompactAndRefine, ResponseSynthesizer } from "./synthesizers"; +import { BaseSynthesizer } from "./synthesizers/types"; /** * A query engine is a question answerer that can use one or more steps. @@ -30,13 +31,13 @@ export interface BaseQueryEngine { */ export class RetrieverQueryEngine implements BaseQueryEngine { retriever: BaseRetriever; - responseSynthesizer: ResponseSynthesizer; + responseSynthesizer: BaseSynthesizer; nodePostprocessors: BaseNodePostprocessor[]; preFilters?: unknown; constructor( retriever: BaseRetriever, - responseSynthesizer?: ResponseSynthesizer, + responseSynthesizer?: BaseSynthesizer, preFilters?: unknown, nodePostprocessors?: BaseNodePostprocessor[], ) { @@ -81,14 +82,14 @@ export class RetrieverQueryEngine implements BaseQueryEngine { * SubQuestionQueryEngine decomposes a question into subquestions and then */ export class SubQuestionQueryEngine implements BaseQueryEngine { - responseSynthesizer: ResponseSynthesizer; + responseSynthesizer: BaseSynthesizer; questionGen: BaseQuestionGenerator; queryEngines: Record; metadatas: ToolMetadata[]; constructor(init: { questionGen: BaseQuestionGenerator; - responseSynthesizer: ResponseSynthesizer; + responseSynthesizer: BaseSynthesizer; queryEngineTools: QueryEngineTool[]; }) { this.questionGen = init.questionGen; @@ -106,7 +107,7 @@ export class SubQuestionQueryEngine implements BaseQueryEngine { static fromDefaults(init: { queryEngineTools: QueryEngineTool[]; questionGen?: BaseQuestionGenerator; - responseSynthesizer?: ResponseSynthesizer; + responseSynthesizer?: BaseSynthesizer; serviceContext?: ServiceContext; }) { const serviceContext = diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index eacdfc6cc9..d71560dfd1 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -9,7 +9,7 @@ export * from "./PromptHelper"; export * from "./QueryEngine"; export * from "./QuestionGenerator"; export * from "./Response"; -export * from "./ResponseSynthesizer"; +export * from "./synthesizers"; export * from "./Retriever"; export * from "./ServiceContext"; export * from "./TextSplitter"; diff --git a/packages/core/src/indices/BaseIndex.ts b/packages/core/src/indices/BaseIndex.ts index c9c0d66f33..3d44799a74 100644 --- a/packages/core/src/indices/BaseIndex.ts +++ b/packages/core/src/indices/BaseIndex.ts @@ -1,13 +1,13 @@ import { v4 as uuidv4 } from "uuid"; import { BaseNode, Document, jsonToNode } from "../Node"; import { BaseQueryEngine } from "../QueryEngine"; -import { ResponseSynthesizer } from "../ResponseSynthesizer"; import { BaseRetriever } from "../Retriever"; import { ServiceContext } from "../ServiceContext"; import { StorageContext } from "../storage/StorageContext"; import { BaseDocumentStore } from "../storage/docStore/types"; import { BaseIndexStore } from "../storage/indexStore/types"; import { VectorStore } from "../storage/vectorStore/types"; +import { BaseSynthesizer } from "../synthesizers"; /** * The underlying structure of each index. @@ -180,7 +180,7 @@ export abstract class BaseIndex { */ abstract asQueryEngine(options?: { retriever?: BaseRetriever; - responseSynthesizer?: ResponseSynthesizer; + responseSynthesizer?: BaseSynthesizer; }): BaseQueryEngine; /** diff --git a/packages/core/src/indices/keyword/KeywordTableIndex.ts b/packages/core/src/indices/keyword/KeywordTableIndex.ts index 9e5efd8689..406809d88f 100644 --- a/packages/core/src/indices/keyword/KeywordTableIndex.ts +++ b/packages/core/src/indices/keyword/KeywordTableIndex.ts @@ -1,7 +1,6 @@ import { BaseNode, Document, MetadataMode } from "../../Node"; import { defaultKeywordExtractPrompt } from "../../Prompt"; import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine"; -import { ResponseSynthesizer } from "../../ResponseSynthesizer"; import { BaseRetriever } from "../../Retriever"; import { ServiceContext, @@ -9,6 +8,7 @@ import { } from "../../ServiceContext"; import { StorageContext, storageContextFromDefaults } from "../../storage"; import { BaseDocumentStore } from "../../storage/docStore/types"; +import { BaseSynthesizer } from "../../synthesizers"; import { BaseIndex, BaseIndexInit, @@ -129,7 +129,7 @@ export class KeywordTableIndex extends BaseIndex { asQueryEngine(options?: { retriever?: BaseRetriever; - responseSynthesizer?: ResponseSynthesizer; + responseSynthesizer?: BaseSynthesizer; preFilters?: unknown; nodePostprocessors?: BaseNodePostprocessor[]; }): BaseQueryEngine { diff --git a/packages/core/src/indices/summary/SummaryIndex.ts b/packages/core/src/indices/summary/SummaryIndex.ts index 57659e9ee5..eb6f753330 100644 --- a/packages/core/src/indices/summary/SummaryIndex.ts +++ b/packages/core/src/indices/summary/SummaryIndex.ts @@ -1,20 +1,21 @@ import _ from "lodash"; import { BaseNode, Document } from "../../Node"; import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine"; -import { - CompactAndRefine, - ResponseSynthesizer, -} from "../../ResponseSynthesizer"; import { BaseRetriever } from "../../Retriever"; import { ServiceContext, serviceContextFromDefaults, } from "../../ServiceContext"; -import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types"; import { StorageContext, storageContextFromDefaults, } from "../../storage/StorageContext"; +import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types"; +import { + BaseSynthesizer, + CompactAndRefine, + ResponseSynthesizer, +} from "../../synthesizers"; import { BaseIndex, BaseIndexInit, @@ -155,7 +156,7 @@ export class SummaryIndex extends BaseIndex { asQueryEngine(options?: { retriever?: BaseRetriever; - responseSynthesizer?: ResponseSynthesizer; + responseSynthesizer?: BaseSynthesizer; preFilters?: unknown; nodePostprocessors?: BaseNodePostprocessor[]; }): BaseQueryEngine { diff --git a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts index eb66b9c7c8..2b818f0063 100644 --- a/packages/core/src/indices/vectorStore/VectorStoreIndex.ts +++ b/packages/core/src/indices/vectorStore/VectorStoreIndex.ts @@ -4,11 +4,10 @@ import { ImageNode, MetadataMode, ObjectType, - TextNode, jsonToNode, + splitNodesByType, } from "../../Node"; import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine"; -import { ResponseSynthesizer } from "../../ResponseSynthesizer"; import { BaseRetriever } from "../../Retriever"; import { ServiceContext, @@ -25,6 +24,7 @@ import { } from "../../storage/StorageContext"; import { BaseIndexStore } from "../../storage/indexStore/types"; import { VectorStore } from "../../storage/vectorStore/types"; +import { ResponseSynthesizer, BaseSynthesizer } from "../../synthesizers"; import { BaseIndex, BaseIndexInit, @@ -248,7 +248,7 @@ export class VectorStoreIndex extends BaseIndex { asQueryEngine(options?: { retriever?: BaseRetriever; - responseSynthesizer?: ResponseSynthesizer; + responseSynthesizer?: BaseSynthesizer; preFilters?: unknown; nodePostprocessors?: BaseNodePostprocessor[]; }): BaseQueryEngine { @@ -290,7 +290,7 @@ export class VectorStoreIndex extends BaseIndex { if (!nodes || nodes.length === 0) { return; } - const { imageNodes, textNodes } = this.splitNodes(nodes); + const { imageNodes, textNodes } = splitNodesByType(nodes); if (imageNodes.length > 0) { if (!this.imageVectorStore) { throw new Error("Cannot insert image nodes without image vector store"); @@ -368,24 +368,4 @@ export class VectorStoreIndex extends BaseIndex { return nodesWithEmbeddings; } - - private splitNodes(nodes: BaseNode[]): { - imageNodes: ImageNode[]; - textNodes: TextNode[]; - } { - let imageNodes: ImageNode[] = []; - let textNodes: TextNode[] = []; - - for (let node of nodes) { - if (node instanceof ImageNode) { - imageNodes.push(node); - } else if (node instanceof TextNode) { - textNodes.push(node); - } - } - return { - imageNodes, - textNodes, - }; - } } diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index 95442beb63..7964f5bd55 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -27,6 +27,7 @@ import { import { getOpenAISession, OpenAISession } from "./openai"; import { getPortkeySession, PortkeySession } from "./portkey"; import { ReplicateSession } from "./replicate"; +import { MessageContent } from "../ChatEngine"; export type MessageType = | "user" @@ -89,7 +90,7 @@ export interface LLM { T extends boolean | undefined = undefined, R = T extends true ? AsyncGenerator : ChatResponse, >( - prompt: string, + prompt: MessageContent, parentEvent?: Event, streaming?: T, ): Promise; diff --git a/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts new file mode 100644 index 0000000000..eeb27fcd5f --- /dev/null +++ b/packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts @@ -0,0 +1,50 @@ +import { MessageContentDetail } from "../ChatEngine"; +import { MetadataMode, NodeWithScore, splitNodesByType } from "../Node"; +import { Response } from "../Response"; +import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext"; +import { Event } from "../callbacks/CallbackManager"; +import { TextQaPrompt, defaultTextQaPrompt } from "./../Prompt"; +import { BaseSynthesizer } from "./types"; + +export class MultiModalResponseSynthesizer implements BaseSynthesizer { + serviceContext: ServiceContext; + metadataMode: MetadataMode; + textQATemplate: TextQaPrompt; + + constructor({ + serviceContext, + textQATemplate, + metadataMode, + }: Partial = {}) { + this.serviceContext = serviceContext ?? serviceContextFromDefaults(); + this.metadataMode = metadataMode ?? MetadataMode.NONE; + this.textQATemplate = textQATemplate ?? defaultTextQaPrompt; + } + + async synthesize( + query: string, + nodesWithScore: NodeWithScore[], + parentEvent?: Event, + ): Promise { + const nodes = nodesWithScore.map(({ node }) => node); + const { imageNodes, textNodes } = splitNodesByType(nodes); + const textChunks = textNodes.map((node) => + node.getContent(this.metadataMode), + ); + // TODO: use builders to generate context + const context = textChunks.join("\n\n"); + const textPrompt = this.textQATemplate({ context, query }); + // TODO: get images from imageNodes + const prompt: MessageContentDetail[] = [ + { type: "text", text: textPrompt }, + { + type: "image_url", + image_url: { + url: "https://upload.wikimedia.org/wikipedia/commons/b/b0/Vincent_van_Gogh_%281853-1890%29_Caf%C3%A9terras_bij_nacht_%28place_du_Forum%29_Kr%C3%B6ller-M%C3%BCller_Museum_Otterlo_23-8-2016_13-35-40.JPG", + }, + }, + ]; + let response = await this.serviceContext.llm.complete(prompt, parentEvent); + return new Response(response.message.content, nodes); + } +} diff --git a/packages/core/src/synthesizers/ResponseSynthesizer.ts b/packages/core/src/synthesizers/ResponseSynthesizer.ts new file mode 100644 index 0000000000..4c290683b3 --- /dev/null +++ b/packages/core/src/synthesizers/ResponseSynthesizer.ts @@ -0,0 +1,49 @@ +import { Event } from "../callbacks/CallbackManager"; +import { MetadataMode, NodeWithScore } from "../Node"; +import { Response } from "../Response"; +import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext"; +import { BaseResponseBuilder, getResponseBuilder } from "./builders"; +import { BaseSynthesizer } from "./types"; + +/** + * A ResponseSynthesizer is used to generate a response from a query and a list of nodes. + */ +export class ResponseSynthesizer implements BaseSynthesizer { + responseBuilder: BaseResponseBuilder; + serviceContext: ServiceContext; + metadataMode: MetadataMode; + + constructor({ + responseBuilder, + serviceContext, + metadataMode = MetadataMode.NONE, + }: { + responseBuilder?: BaseResponseBuilder; + serviceContext?: ServiceContext; + metadataMode?: MetadataMode; + } = {}) { + this.serviceContext = serviceContext ?? serviceContextFromDefaults(); + this.responseBuilder = + responseBuilder ?? getResponseBuilder(this.serviceContext); + this.metadataMode = metadataMode; + } + + async synthesize( + query: string, + nodesWithScore: NodeWithScore[], + parentEvent?: Event, + ) { + let textChunks: string[] = nodesWithScore.map(({ node }) => + node.getContent(this.metadataMode), + ); + const response = await this.responseBuilder.getResponse( + query, + textChunks, + parentEvent, + ); + return new Response( + response, + nodesWithScore.map(({ node }) => node), + ); + } +} diff --git a/packages/core/src/ResponseSynthesizer.ts b/packages/core/src/synthesizers/builders.ts similarity index 82% rename from packages/core/src/ResponseSynthesizer.ts rename to packages/core/src/synthesizers/builders.ts index f3151a6dcc..33c8aac562 100644 --- a/packages/core/src/ResponseSynthesizer.ts +++ b/packages/core/src/synthesizers/builders.ts @@ -1,6 +1,5 @@ -import { Event } from "./callbacks/CallbackManager"; -import { LLM } from "./llm/LLM"; -import { MetadataMode, NodeWithScore } from "./Node"; +import { Event } from "../callbacks/CallbackManager"; +import { LLM } from "../llm/LLM"; import { defaultRefinePrompt, defaultTextQaPrompt, @@ -9,10 +8,9 @@ import { SimplePrompt, TextQaPrompt, TreeSummarizePrompt, -} from "./Prompt"; -import { getBiggestPrompt } from "./PromptHelper"; -import { Response } from "./Response"; -import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext"; +} from "../Prompt"; +import { getBiggestPrompt } from "../PromptHelper"; +import { ServiceContext } from "../ServiceContext"; /** * Response modes of the response synthesizer @@ -27,7 +25,7 @@ enum ResponseMode { /** * A ResponseBuilder is used in a response synthesizer to generate a response from multiple response chunks. */ -interface BaseResponseBuilder { +export interface BaseResponseBuilder { /** * Get the response from a query and a list of text chunks. * @param query @@ -283,46 +281,3 @@ export function getResponseBuilder( return new CompactAndRefine(serviceContext); } } - -/** - * A ResponseSynthesizer is used to generate a response from a query and a list of nodes. - */ -export class ResponseSynthesizer { - responseBuilder: BaseResponseBuilder; - serviceContext: ServiceContext; - metadataMode: MetadataMode; - - constructor({ - responseBuilder, - serviceContext, - metadataMode = MetadataMode.NONE, - }: { - responseBuilder?: BaseResponseBuilder; - serviceContext?: ServiceContext; - metadataMode?: MetadataMode; - } = {}) { - this.serviceContext = serviceContext ?? serviceContextFromDefaults(); - this.responseBuilder = - responseBuilder ?? getResponseBuilder(this.serviceContext); - this.metadataMode = metadataMode; - } - - async synthesize( - query: string, - nodesWithScore: NodeWithScore[], - parentEvent?: Event, - ) { - let textChunks: string[] = nodesWithScore.map(({ node }) => - node.getContent(this.metadataMode), - ); - const response = await this.responseBuilder.getResponse( - query, - textChunks, - parentEvent, - ); - return new Response( - response, - nodesWithScore.map(({ node }) => node), - ); - } -} diff --git a/packages/core/src/synthesizers/index.ts b/packages/core/src/synthesizers/index.ts new file mode 100644 index 0000000000..2ec58d918d --- /dev/null +++ b/packages/core/src/synthesizers/index.ts @@ -0,0 +1,4 @@ +export * from "./MultiModalResponseSynthesizer"; +export * from "./ResponseSynthesizer"; +export * from "./builders"; +export * from "./types"; diff --git a/packages/core/src/synthesizers/types.ts b/packages/core/src/synthesizers/types.ts new file mode 100644 index 0000000000..c465db871f --- /dev/null +++ b/packages/core/src/synthesizers/types.ts @@ -0,0 +1,15 @@ +import { Event } from "../callbacks/CallbackManager"; +import { NodeWithScore } from "../Node"; +import { Response } from "../Response"; + +/** + * A BaseSynthesizer is used to generate a response from a query and a list of nodes. + * TODO: convert response builders to implement this interface (similar to Python). + */ +export interface BaseSynthesizer { + synthesize( + query: string, + nodesWithScore: NodeWithScore[], + parentEvent?: Event, + ): Promise; +} diff --git a/packages/core/src/tests/CallbackManager.test.ts b/packages/core/src/tests/CallbackManager.test.ts index 8d1e1648a9..f9ef370b46 100644 --- a/packages/core/src/tests/CallbackManager.test.ts +++ b/packages/core/src/tests/CallbackManager.test.ts @@ -11,7 +11,7 @@ import { Document } from "../Node"; import { ResponseSynthesizer, SimpleResponseBuilder, -} from "../ResponseSynthesizer"; +} from "../synthesizers"; import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext"; import { mockEmbeddingModel, mockLlmGeneration } from "./utility/mockOpenAI";