Skip to content

Commit

Permalink
feat: ensure retriever returns an image and send it to the LLM base64…
Browse files Browse the repository at this point in the history
… encoded
  • Loading branch information
marcusschiesser committed Dec 15, 2023
1 parent 399e394 commit 04730e3
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 340 deletions.
Binary file removed examples/multimodal/data/1.jpg
Binary file not shown.
Binary file removed examples/multimodal/data/2.jpg
Binary file not shown.
Binary file removed examples/multimodal/data/3.jpg
Binary file not shown.
323 changes: 0 additions & 323 deletions examples/multimodal/data/San Francisco.txt

This file was deleted.

4 changes: 2 additions & 2 deletions examples/multimodal/rag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ async function main() {

const queryEngine = index.asQueryEngine({
responseSynthesizer: new MultiModalResponseSynthesizer({ serviceContext }),
// TODO: set imageSimilarityTopK: 1,
retriever: index.asRetriever({ similarityTopK: 2 }),
// TODO: set text similarity to a higher value than image similarity
retriever: index.asRetriever({ similarityTopK: 1 }),
});
const result = await queryEngine.query(
"what are Vincent van Gogh's famous paintings",
Expand Down
3 changes: 1 addition & 2 deletions examples/multimodal/retrieve.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import {
TextNode,
VectorStoreIndex,
} from "llamaindex";
import * as path from "path";

export async function createIndex() {
// set up vector store index with two vector stores, one for text, the other for images
Expand Down Expand Up @@ -37,7 +36,7 @@ async function main() {
continue;
}
if (node instanceof ImageNode) {
console.log(`Image: ${path.join(__dirname, node.id_)}`);
console.log(`Image: ${node.getUrl()}`);
} else if (node instanceof TextNode) {
console.log("Text:", (node as TextNode).text.substring(0, 128));
}
Expand Down
1 change: 1 addition & 0 deletions packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"@xenova/transformers": "^2.10.0",
"assemblyai": "^4.0.0",
"crypto-js": "^4.2.0",
"file-type": "^18.7.0",
"js-tiktoken": "^1.0.8",
"lodash": "^4.17.21",
"mammoth": "^1.6.0",
Expand Down
7 changes: 7 additions & 0 deletions packages/core/src/Node.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import CryptoJS from "crypto-js";
import path from "path";
import { v4 as uuidv4 } from "uuid";

export enum NodeRelationship {
Expand Down Expand Up @@ -304,6 +305,12 @@ export class ImageNode<T extends Metadata = Metadata> extends TextNode<T> {
getType(): ObjectType {
return ObjectType.IMAGE;
}

getUrl(): URL {
// id_ stores the relative path, convert it to the URL of the file
const absPath = path.resolve(this.id_);
return new URL(`file://${absPath}`);
}
}

export class ImageDocument<T extends Metadata = Metadata> extends ImageNode<T> {
Expand Down
62 changes: 61 additions & 1 deletion packages/core/src/embeddings/utils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import _ from "lodash";
import { ImageType } from "../Node";
import { DEFAULT_SIMILARITY_TOP_K } from "../constants";
import { VectorStoreQueryMode } from "../storage";
import { DEFAULT_FS, VectorStoreQueryMode } from "../storage";
import { SimilarityType } from "./types";

/**
Expand Down Expand Up @@ -185,6 +185,16 @@ export function getTopKMMREmbeddings(
return [resultSimilarities, resultIds];
}

async function blobToDataUrl(input: Blob) {
const { fileTypeFromBuffer } = await import("file-type");
const buffer = Buffer.from(await input.arrayBuffer());
const type = await fileTypeFromBuffer(buffer);
if (!type) {
throw new Error("Unsupported image type");
}
return "data:" + type.mime + ";base64," + buffer.toString("base64");
}

export async function readImage(input: ImageType) {
const { RawImage } = await import("@xenova/transformers");
if (input instanceof Blob) {
Expand All @@ -195,3 +205,53 @@ export async function readImage(input: ImageType) {
throw new Error(`Unsupported input type: ${typeof input}`);
}
}

export async function imageToString(input: ImageType): Promise<string> {
if (input instanceof Blob) {
// if the image is a Blob, convert it to a base64 data URL
return await blobToDataUrl(input);
} else if (_.isString(input)) {
return input;
} else if (input instanceof URL) {
return input.toString();
} else {
throw new Error(`Unsupported input type: ${typeof input}`);
}
}

export function stringToImage(input: string): ImageType {
if (input.startsWith("data:")) {
// if the input is a base64 data URL, convert it back to a Blob
const base64Data = input.split(",")[1];
const byteArray = Buffer.from(base64Data, "base64");
return new Blob([byteArray]);
} else if (input.startsWith("http://") || input.startsWith("https://")) {
return new URL(input);
} else if (_.isString(input)) {
return input;
} else {
throw new Error(`Unsupported input type: ${typeof input}`);
}
}

export async function imageToDataUrl(input: ImageType): Promise<string> {
// first ensure, that the input is a Blob
if (
(input instanceof URL && input.protocol === "file:") ||
_.isString(input)
) {
// string or file URL
const fs = DEFAULT_FS;
const dataBuffer = await fs.readFile(
input instanceof URL ? input.pathname : input,
);
input = new Blob([dataBuffer]);
} else if (!(input instanceof Blob)) {
if (input instanceof URL) {
throw new Error(`Unsupported URL with protocol: ${input.protocol}`);
} else {
throw new Error(`Unsupported input type: ${typeof input}`);
}
}
return await blobToDataUrl(input);
}
14 changes: 10 additions & 4 deletions packages/core/src/indices/vectorStore/VectorIndexRetriever.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { Event } from "../../callbacks/CallbackManager";
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
import { BaseEmbedding } from "../../embeddings";
import { globalsHelper } from "../../GlobalsHelper";
import { Metadata, NodeWithScore } from "../../Node";
import { ImageNode, Metadata, NodeWithScore } from "../../Node";
import { BaseRetriever } from "../../Retriever";
import { ServiceContext } from "../../ServiceContext";
import { Event } from "../../callbacks/CallbackManager";
import { DEFAULT_SIMILARITY_TOP_K } from "../../constants";
import { BaseEmbedding } from "../../embeddings";
import {
VectorStoreQuery,
VectorStoreQueryMode,
Expand Down Expand Up @@ -108,6 +108,12 @@ export class VectorIndexRetriever implements BaseRetriever {
}

const node = this.index.indexStruct.nodesDict[result.ids[i]];
// XXX: Hack, if it's an image node, we reconstruct the image from the URL
// Alternative: Store image in doc store and retrieve it here
if (node instanceof ImageNode) {
node.image = node.getUrl();
}

nodesWithScores.push({
node: node,
score: result.similarities[i],
Expand Down
26 changes: 18 additions & 8 deletions packages/core/src/synthesizers/MultiModalResponseSynthesizer.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import { MessageContentDetail } from "../ChatEngine";
import { MetadataMode, NodeWithScore, splitNodesByType } from "../Node";
import {
ImageNode,
MetadataMode,
NodeWithScore,
splitNodesByType,
} from "../Node";
import { Response } from "../Response";
import { ServiceContext, serviceContextFromDefaults } from "../ServiceContext";
import { Event } from "../callbacks/CallbackManager";
import { imageToDataUrl } from "../embeddings";
import { TextQaPrompt, defaultTextQaPrompt } from "./../Prompt";
import { BaseSynthesizer } from "./types";

Expand Down Expand Up @@ -34,15 +40,19 @@ export class MultiModalResponseSynthesizer implements BaseSynthesizer {
// TODO: use builders to generate context
const context = textChunks.join("\n\n");
const textPrompt = this.textQATemplate({ context, query });
// TODO: get images from imageNodes
const images = await Promise.all(
imageNodes.map(async (node: ImageNode) => {
return {
type: "image_url",
image_url: {
url: await imageToDataUrl(node.image),
},
} as MessageContentDetail;
}),
);
const prompt: MessageContentDetail[] = [
{ type: "text", text: textPrompt },
{
type: "image_url",
image_url: {
url: "https://upload.wikimedia.org/wikipedia/commons/b/b0/Vincent_van_Gogh_%281853-1890%29_Caf%C3%A9terras_bij_nacht_%28place_du_Forum%29_Kr%C3%B6ller-M%C3%BCller_Museum_Otterlo_23-8-2016_13-35-40.JPG",
},
},
...images,
];
let response = await this.serviceContext.llm.complete(prompt, parentEvent);
return new Response(response.message.content, nodes);
Expand Down
44 changes: 44 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 04730e3

Please sign in to comment.