Skip to content

Commit

Permalink
feat: added mm-rag example and started the response synthesis for it
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusschiesser committed Dec 18, 2023
1 parent 76f7256 commit f4b0961
Show file tree
Hide file tree
Showing 16 changed files with 219 additions and 100 deletions.
43 changes: 43 additions & 0 deletions examples/multimodal/rag.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import {
MultiModalResponseSynthesizer,
OpenAI,
ServiceContext,
VectorStoreIndex,
serviceContextFromDefaults,
storageContextFromDefaults,
} from "llamaindex";

/**
 * Creates the vector store index used by this example.
 *
 * The storage context is persisted under the "storage" directory and is
 * requested with `storeImages: true`, so the index is set up with two vector
 * stores: one for text and one for images. No nodes are passed in here.
 *
 * @param serviceContext - service context handed to the index.
 * @returns the initialized index.
 */
export async function createIndex(serviceContext: ServiceContext) {
  const storageContext = await storageContextFromDefaults({
    storeImages: true,
    persistDir: "storage",
  });

  const index = await VectorStoreIndex.init({
    serviceContext,
    storageContext,
    nodes: [],
  });
  return index;
}

/**
 * Example entry point: builds the index, wires a query engine that uses the
 * multi-modal response synthesizer, runs one sample query and prints the
 * response text.
 */
async function main() {
  const visionLlm = new OpenAI({
    maxTokens: 512,
    model: "gpt-4-vision-preview",
  });
  const context = serviceContextFromDefaults({
    llm: visionLlm,
    chunkSize: 512,
    chunkOverlap: 20,
  });
  const index = await createIndex(context);

  const responseSynthesizer = new MultiModalResponseSynthesizer({
    serviceContext: context,
  });
  // TODO: set imageSimilarityTopK: 1,
  const retriever = index.asRetriever({ similarityTopK: 2 });
  const queryEngine = index.asQueryEngine({ responseSynthesizer, retriever });

  const answer = await queryEngine.query(
    "what are Vincent van Gogh's famous paintings",
  );
  console.log(answer.response);
}

// Any rejection from main() is reported on stderr.
main().catch(console.error);
8 changes: 4 additions & 4 deletions examples/multimodal/retrieve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
} from "llamaindex";
import * as path from "path";

export async function createRetriever() {
export async function createIndex() {
// set up vector store index with two vector stores, one for text, the other for images
const serviceContext = serviceContextFromDefaults({
chunkSize: 512,
Expand All @@ -17,17 +17,17 @@ export async function createRetriever() {
persistDir: "storage",
storeImages: true,
});
const index = await VectorStoreIndex.init({
return await VectorStoreIndex.init({
nodes: [],
storageContext,
serviceContext,
});
return index.asRetriever({ similarityTopK: 3 });
}

async function main() {
// retrieve documents using the index
const retriever = await createRetriever();
const index = await createIndex();
const retriever = index.asRetriever({ similarityTopK: 3 });
const results = await retriever.retrieve(
"what are Vincent van Gogh's famous paintings",
);
Expand Down
20 changes: 20 additions & 0 deletions packages/core/src/Node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,23 @@ export interface NodeWithScore<T extends Metadata = Metadata> {
node: BaseNode<T>;
score?: number;
}

/**
 * Partitions a list of nodes into image nodes and text nodes.
 *
 * A node that is an instance of `ImageNode` is reported only in `imageNodes`,
 * even if it also satisfies `instanceof TextNode`. Nodes that are instances of
 * neither class are left out of both lists. Input order is preserved within
 * each list.
 *
 * @param nodes - nodes to classify.
 * @returns the image nodes and the text nodes found in `nodes`.
 */
export function splitNodesByType(nodes: BaseNode[]): {
  imageNodes: ImageNode[];
  textNodes: TextNode[];
} {
  const imageNodes = nodes.filter(
    (candidate): candidate is ImageNode => candidate instanceof ImageNode,
  );
  // Exclude image nodes explicitly so they are never counted twice.
  const textNodes = nodes.filter(
    (candidate): candidate is TextNode =>
      candidate instanceof TextNode && !(candidate instanceof ImageNode),
  );
  return { imageNodes, textNodes };
}
17 changes: 9 additions & 8 deletions packages/core/src/QueryEngine.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import { v4 as uuidv4 } from "uuid";
import { Event } from "./callbacks/CallbackManager";
import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
import { NodeWithScore, TextNode } from "./Node";
import {
BaseQuestionGenerator,
LLMQuestionGenerator,
SubQuestion,
} from "./QuestionGenerator";
import { Response } from "./Response";
import { CompactAndRefine, ResponseSynthesizer } from "./ResponseSynthesizer";
import { BaseRetriever } from "./Retriever";
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
import { QueryEngineTool, ToolMetadata } from "./Tool";
import { Event } from "./callbacks/CallbackManager";
import { BaseNodePostprocessor } from "./indices/BaseNodePostprocessor";
import { CompactAndRefine, ResponseSynthesizer } from "./synthesizers";
import { BaseSynthesizer } from "./synthesizers/types";

/**
* A query engine is a question answerer that can use one or more steps.
Expand All @@ -30,13 +31,13 @@ export interface BaseQueryEngine {
*/
export class RetrieverQueryEngine implements BaseQueryEngine {
retriever: BaseRetriever;
responseSynthesizer: ResponseSynthesizer;
responseSynthesizer: BaseSynthesizer;
nodePostprocessors: BaseNodePostprocessor[];
preFilters?: unknown;

constructor(
retriever: BaseRetriever,
responseSynthesizer?: ResponseSynthesizer,
responseSynthesizer?: BaseSynthesizer,
preFilters?: unknown,
nodePostprocessors?: BaseNodePostprocessor[],
) {
Expand Down Expand Up @@ -81,14 +82,14 @@ export class RetrieverQueryEngine implements BaseQueryEngine {
* SubQuestionQueryEngine decomposes a question into subquestions.
*/
export class SubQuestionQueryEngine implements BaseQueryEngine {
responseSynthesizer: ResponseSynthesizer;
responseSynthesizer: BaseSynthesizer;
questionGen: BaseQuestionGenerator;
queryEngines: Record<string, BaseQueryEngine>;
metadatas: ToolMetadata[];

constructor(init: {
questionGen: BaseQuestionGenerator;
responseSynthesizer: ResponseSynthesizer;
responseSynthesizer: BaseSynthesizer;
queryEngineTools: QueryEngineTool[];
}) {
this.questionGen = init.questionGen;
Expand All @@ -106,7 +107,7 @@ export class SubQuestionQueryEngine implements BaseQueryEngine {
static fromDefaults(init: {
queryEngineTools: QueryEngineTool[];
questionGen?: BaseQuestionGenerator;
responseSynthesizer?: ResponseSynthesizer;
responseSynthesizer?: BaseSynthesizer;
serviceContext?: ServiceContext;
}) {
const serviceContext =
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export * from "./PromptHelper";
export * from "./QueryEngine";
export * from "./QuestionGenerator";
export * from "./Response";
export * from "./ResponseSynthesizer";
export * from "./synthesizers";
export * from "./Retriever";
export * from "./ServiceContext";
export * from "./TextSplitter";
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/indices/BaseIndex.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import { v4 as uuidv4 } from "uuid";
import { BaseNode, Document, jsonToNode } from "../Node";
import { BaseQueryEngine } from "../QueryEngine";
import { ResponseSynthesizer } from "../ResponseSynthesizer";
import { BaseRetriever } from "../Retriever";
import { ServiceContext } from "../ServiceContext";
import { StorageContext } from "../storage/StorageContext";
import { BaseDocumentStore } from "../storage/docStore/types";
import { BaseIndexStore } from "../storage/indexStore/types";
import { VectorStore } from "../storage/vectorStore/types";
import { BaseSynthesizer } from "../synthesizers";

/**
* The underlying structure of each index.
Expand Down Expand Up @@ -180,7 +180,7 @@ export abstract class BaseIndex<T> {
*/
abstract asQueryEngine(options?: {
retriever?: BaseRetriever;
responseSynthesizer?: ResponseSynthesizer;
responseSynthesizer?: BaseSynthesizer;
}): BaseQueryEngine;

/**
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/indices/keyword/KeywordTableIndex.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import { BaseNode, Document, MetadataMode } from "../../Node";
import { defaultKeywordExtractPrompt } from "../../Prompt";
import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine";
import { ResponseSynthesizer } from "../../ResponseSynthesizer";
import { BaseRetriever } from "../../Retriever";
import {
ServiceContext,
serviceContextFromDefaults,
} from "../../ServiceContext";
import { StorageContext, storageContextFromDefaults } from "../../storage";
import { BaseDocumentStore } from "../../storage/docStore/types";
import { BaseSynthesizer } from "../../synthesizers";
import {
BaseIndex,
BaseIndexInit,
Expand Down Expand Up @@ -129,7 +129,7 @@ export class KeywordTableIndex extends BaseIndex<KeywordTable> {

asQueryEngine(options?: {
retriever?: BaseRetriever;
responseSynthesizer?: ResponseSynthesizer;
responseSynthesizer?: BaseSynthesizer;
preFilters?: unknown;
nodePostprocessors?: BaseNodePostprocessor[];
}): BaseQueryEngine {
Expand Down
13 changes: 7 additions & 6 deletions packages/core/src/indices/summary/SummaryIndex.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import _ from "lodash";
import { BaseNode, Document } from "../../Node";
import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine";
import {
CompactAndRefine,
ResponseSynthesizer,
} from "../../ResponseSynthesizer";
import { BaseRetriever } from "../../Retriever";
import {
ServiceContext,
serviceContextFromDefaults,
} from "../../ServiceContext";
import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types";
import {
StorageContext,
storageContextFromDefaults,
} from "../../storage/StorageContext";
import { BaseDocumentStore, RefDocInfo } from "../../storage/docStore/types";
import {
BaseSynthesizer,
CompactAndRefine,
ResponseSynthesizer,
} from "../../synthesizers";
import {
BaseIndex,
BaseIndexInit,
Expand Down Expand Up @@ -155,7 +156,7 @@ export class SummaryIndex extends BaseIndex<IndexList> {

asQueryEngine(options?: {
retriever?: BaseRetriever;
responseSynthesizer?: ResponseSynthesizer;
responseSynthesizer?: BaseSynthesizer;
preFilters?: unknown;
nodePostprocessors?: BaseNodePostprocessor[];
}): BaseQueryEngine {
Expand Down
28 changes: 4 additions & 24 deletions packages/core/src/indices/vectorStore/VectorStoreIndex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ import {
ImageNode,
MetadataMode,
ObjectType,
TextNode,
jsonToNode,
splitNodesByType,
} from "../../Node";
import { BaseQueryEngine, RetrieverQueryEngine } from "../../QueryEngine";
import { ResponseSynthesizer } from "../../ResponseSynthesizer";
import { BaseRetriever } from "../../Retriever";
import {
ServiceContext,
Expand All @@ -25,6 +24,7 @@ import {
} from "../../storage/StorageContext";
import { BaseIndexStore } from "../../storage/indexStore/types";
import { VectorStore } from "../../storage/vectorStore/types";
import { ResponseSynthesizer, BaseSynthesizer } from "../../synthesizers";
import {
BaseIndex,
BaseIndexInit,
Expand Down Expand Up @@ -248,7 +248,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {

asQueryEngine(options?: {
retriever?: BaseRetriever;
responseSynthesizer?: ResponseSynthesizer;
responseSynthesizer?: BaseSynthesizer;
preFilters?: unknown;
nodePostprocessors?: BaseNodePostprocessor[];
}): BaseQueryEngine {
Expand Down Expand Up @@ -290,7 +290,7 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {
if (!nodes || nodes.length === 0) {
return;
}
const { imageNodes, textNodes } = this.splitNodes(nodes);
const { imageNodes, textNodes } = splitNodesByType(nodes);
if (imageNodes.length > 0) {
if (!this.imageVectorStore) {
throw new Error("Cannot insert image nodes without image vector store");
Expand Down Expand Up @@ -368,24 +368,4 @@ export class VectorStoreIndex extends BaseIndex<IndexDict> {

return nodesWithEmbeddings;
}

/**
 * Partitions the given nodes into image nodes and text nodes.
 * Nodes that are instances of neither ImageNode nor TextNode are left out
 * of both lists.
 * @param nodes nodes to classify
 * @returns the image nodes and the text nodes, each in input order
 */
private splitNodes(nodes: BaseNode[]): {
imageNodes: ImageNode[];
textNodes: TextNode[];
} {
let imageNodes: ImageNode[] = [];
let textNodes: TextNode[] = [];

for (let node of nodes) {
// ImageNode is tested first, so a node matching it never lands in textNodes.
if (node instanceof ImageNode) {
imageNodes.push(node);
} else if (node instanceof TextNode) {
textNodes.push(node);
}
}
return {
imageNodes,
textNodes,
};
}
}
3 changes: 2 additions & 1 deletion packages/core/src/llm/LLM.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import {
import { getOpenAISession, OpenAISession } from "./openai";
import { getPortkeySession, PortkeySession } from "./portkey";
import { ReplicateSession } from "./replicate";
import { MessageContent } from "../ChatEngine";

export type MessageType =
| "user"
Expand Down Expand Up @@ -89,7 +90,7 @@ export interface LLM {
T extends boolean | undefined = undefined,
R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse,
>(
prompt: string,
prompt: MessageContent,
parentEvent?: Event,
streaming?: T,
): Promise<R>;
Expand Down
50 changes: 50 additions & 0 deletions packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import { MessageContentDetail } from "../ChatEngine";
import { MetadataMode, NodeWithScore, splitNodesByType } from "../Node";
import { Response } from "../Response";
import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
import { Event } from "../callbacks/CallbackManager";
import { TextQaPrompt, defaultTextQaPrompt } from "./../Prompt";
import { BaseSynthesizer } from "./types";

/**
 * Synthesizer that answers a query by sending the LLM a multi-part prompt:
 * a text QA prompt built from the retrieved text nodes, plus an image part.
 *
 * Work in progress: the retrieved image nodes are not used yet (see the TODO
 * in `synthesize`); a fixed image URL is attached to every prompt instead.
 */
export class MultiModalResponseSynthesizer implements BaseSynthesizer {
  serviceContext: ServiceContext;
  metadataMode: MetadataMode;
  textQATemplate: TextQaPrompt;

  /**
   * @param init - optional overrides. Missing fields fall back to
   *   `serviceContextFromDefaults()`, `MetadataMode.NONE` and
   *   `defaultTextQaPrompt` respectively.
   */
  constructor(init: Partial<MultiModalResponseSynthesizer> = {}) {
    this.serviceContext = init.serviceContext ?? serviceContextFromDefaults();
    this.metadataMode = init.metadataMode ?? MetadataMode.NONE;
    this.textQATemplate = init.textQATemplate ?? defaultTextQaPrompt;
  }

  /**
   * Builds the prompt from the retrieved nodes and asks the LLM to complete it.
   *
   * @param query - the user question.
   * @param nodesWithScore - retrieved nodes; only the nodes are used here, not the scores.
   * @param parentEvent - forwarded unchanged to the LLM call.
   * @returns a Response holding the LLM message content and all retrieved nodes.
   */
  async synthesize(
    query: string,
    nodesWithScore: NodeWithScore[],
    parentEvent?: Event,
  ): Promise<Response> {
    const nodes = nodesWithScore.map((scored) => scored.node);
    const { imageNodes, textNodes } = splitNodesByType(nodes);

    // TODO: use builders to generate context
    const context = textNodes
      .map((textNode) => textNode.getContent(this.metadataMode))
      .join("\n\n");
    const textPrompt = this.textQATemplate({ context, query });

    // TODO: get images from imageNodes
    const placeholderImageUrl =
      "https://upload.wikimedia.org/wikipedia/commons/b/b0/Vincent_van_Gogh_%281853-1890%29_Caf%C3%A9terras_bij_nacht_%28place_du_Forum%29_Kr%C3%B6ller-M%C3%BCller_Museum_Otterlo_23-8-2016_13-35-40.JPG";
    const prompt: MessageContentDetail[] = [
      { type: "text", text: textPrompt },
      { type: "image_url", image_url: { url: placeholderImageUrl } },
    ];

    const completion = await this.serviceContext.llm.complete(
      prompt,
      parentEvent,
    );
    return new Response(completion.message.content, nodes);
  }
}
Loading

0 comments on commit f4b0961

Please sign in to comment.