From ae39acae85c6bf04e02b8e0351c19e4b5823de9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Peter=20Stra=C3=9Fer?= Date: Sun, 2 Mar 2025 22:14:43 +0100 Subject: [PATCH] Follow up on part 1 for the Colpali blog. Goes into bit vectors, average vectors and token pooling to make late interaction vectors more scalable. --- notebooks/colpali/02_bit_vectors.ipynb | 394 ++++++++++++++++++++ notebooks/colpali/03_average_vector.ipynb | 421 ++++++++++++++++++++++ notebooks/colpali/04_token_pooling.ipynb | 320 ++++++++++++++++ 3 files changed, 1135 insertions(+) create mode 100644 notebooks/colpali/02_bit_vectors.ipynb create mode 100644 notebooks/colpali/03_average_vector.ipynb create mode 100644 notebooks/colpali/04_token_pooling.ipynb diff --git a/notebooks/colpali/02_bit_vectors.ipynb b/notebooks/colpali/02_bit_vectors.ipynb new file mode 100644 index 00000000..89540d26 --- /dev/null +++ b/notebooks/colpali/02_bit_vectors.ipynb @@ -0,0 +1,394 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e0056726-3359-4a9b-9913-016617525a6d", + "metadata": {}, + "source": [ + "# Scalable late interaction vectors in Elasticsearch: Bit Vectors #\n", + "\n", + "In this notebook, we will be looking at how to convert late interaction vectors to bit vectors to \n", + "1. Save siginificant disk space \n", + "2. Lower query latency\n", + " \n", + "We will also look at how we can use hamming distance to speed our queries up even further. \n", + "This notebook builds on part 1 where we downloaded the images, created ColPali vectors and saved them to disk. Please execute this notebook before trying the techniques in this notebook. \n", + " \n", + "Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. " + ] + }, + { + "cell_type": "markdown", + "id": "49dbcc61-5dab-4cf6-bbc5-7fa898707ce6", + "metadata": {}, + "source": [ + "This is the key part of this notebook. We use the `to_bit_vectors()` function to convert our vectors into bit vectors. \n", + "The function is simple in essence. Values `> 0` are converted to `1`, values `< 0` are converted to `0`. We then convert our array of `0`s and `1`s to a hex string, that represents our bit vector. \n", + "So don't be surprised that the values that we will be indexing look like strings and not arrays as before. This is intended! \n", + "\n", + "Learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def to_bit_vectors(embedding: list) -> list:\n", + " embeddings = []\n", + " for idx, patch_embedding in enumerate(embedding):\n", + " patch_embedding = np.array(patch_embedding)\n", + " binary_vector = (\n", + " np.packbits(np.where(patch_embedding > 0, 1, 0))\n", + " .astype(np.int8)\n", + " .tobytes()\n", + " .hex()\n", + " )\n", + " embeddings.append(binary_vector)\n", + " return embeddings" + ] + }, + { + "cell_type": "markdown", + "id": "52b7449b-8fbf-46b7-90c9-330070f6996a", + "metadata": {}, + "source": [ + "Here we are defining our mapping for our Elasticsearch index. Note how we set the `element_type` parameter to `bit` to inform Elasticsearch that we will be indexing bit vectors in this field. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] Index 'searchlabs-colpali-hamming' already exists.\n" + ] + } + ], + "source": [ + "import os\n", + "from dotenv import load_dotenv\n", + "from elasticsearch import Elasticsearch\n", + "\n", + "load_dotenv(\"elastic.env\")\n", + "\n", + "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n", + "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n", + "INDEX_NAME = \"searchlabs-colpali-hamming\"\n", + "\n", + "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n", + "\n", + "mappings = {\n", + " \"mappings\": {\n", + " \"properties\": {\n", + " \"col_pali_vectors\": {\n", + " \"type\": \"rank_vectors\",\n", + " \"element_type\": \"bit\"\n", + " }\n", + " }\n", + " }\n", + "}\n", + "\n", + "if not es.indices.exists(index=INDEX_NAME):\n", + " print(f\"[INFO] Creating index: {INDEX_NAME}\")\n", + " es.indices.create(index=INDEX_NAME, body=mappings)\n", + "else:\n", + " print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n", + "\n", + "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n", + " for attempt in range(1, retries + 1):\n", + " try:\n", + " return es_client.index(index=index, id=doc_id, document=document)\n", + " except Exception as e:\n", + " if attempt < retries:\n", + " wait_time = initial_backoff * (2 ** (attempt - 1))\n", + " print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n", + " time.sleep(wait_time)\n", + " else:\n", + " print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n", + " raise" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "022b4af8891b4a06962e023c7f92d8f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Indexing documents: 0%| | 0/500 [00:00 list:\n", + " queries = col_pali_processor.process_queries([query]).to(model.device)\n", + " with torch.no_grad():\n", + " return model(**queries).tolist()[0]" + ] + }, + { + "cell_type": "raw", + "id": "5e86697d-d9dd-4224-85c8-023c71c88548", + "metadata": {}, + "source": [ + "Here we run the search against our index comparing our query vector converted to bit vectors to the bit vectors in our index. \n", + "Trading of a bit of accuracy, this is allows us to use hamming distance (`maxSimInvHamming(...)`), which is able to leverage optimzations such as bit-masks, SIMD, etc. Again - learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. \n", + "\n", + "See the cell below about a different technique to query our bit vectors. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\"_source\": false, \"query\": {\"script_score\": {\"query\": {\"match_all\": {}}, \"script\": {\"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\", \"params\": {\"query_vector\": [\"7747bcd9732859c3645aa81036f5c960\", \"729b3c418ba8594a67daa042eca1c961\", \"609e3d8a2ac379c2204aa0cfa8345bdc\", \"30bf378a2ac279da245aa8dfa83c3bdc\", \"64af77ea2acdf9c28c0aa5df863677f4\", \"686f3fce2ac871c26e6aaddf023455ec\", \"383f31a8e8c0f8ca2c4ab54f047c7dec\", \"203b33caaac279da0acaa54f8a3c6bcc\", \"319a63eba8d279ca30dbbccf8f757b8e\", \"203b73ca28d2798a325bb44f8c3c5bce\", \"203bb7caa8d2718a1a4bb14f8a3c5bdc\", \"203bb7caa8d2798a1a6aa14f8a3c5fdc\", \"303b33caa8d2798a0a4aa14f8a3c5bdc\", \"303b33caaad379ca0e4aa14f8a3c5bdc\", \"709b33caaac379ca0c4aa14f8a3c5fdc\", \"708e37eaaac779ca2c4aa1df863c1fdc\", \"648e77ea6acd79caac4ae1df86363ffc\", \"648e77ea6acdf9caac4ae5df06363ffc\", \"608f37ea2ac579ca2c4ea1df063c3ffc\", \"709f37c8aac379ca2c4ea1df863c1fdc\", \"70af31c82ac671ce2c6ab14fc43c1bfc\"]}}}}, \"size\": 5}\n" + ] + }, + { + "data": { + "text/html": [ + "
\"image_104.jpg\"\"image_3.jpg\"\"image_12.jpg\"\"image_2.jpg\"\"image_92.jpg\"
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import display, HTML\n", + "import os\n", + "import json\n", + "\n", + "DOCUMENT_DIR = \"searchlabs-colpali\"\n", + "\n", + "query = \"What do companies use for recruiting?\"\n", + "query_vector = to_bit_vectors(create_col_pali_query_vectors(query))\n", + "es_query = {\n", + " \"_source\": False,\n", + " \"query\": {\n", + " \"script_score\": {\n", + " \"query\": {\n", + " \"match_all\": {}\n", + " },\n", + " \"script\": {\n", + " \"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\",\n", + " \"params\": {\n", + " \"query_vector\": query_vector\n", + " }\n", + " }\n", + " }\n", + " },\n", + " \"size\": 5\n", + "}\n", + "print(json.dumps(es_query))\n", + "\n", + "results = es.search(index=INDEX_NAME, body=es_query)\n", + "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n", + "\n", + "html = \"
\"\n", + "for image_id in image_ids:\n", + " image_path = os.path.join(DOCUMENT_DIR, image_id)\n", + " html += f'\"{image_id}\"'\n", + "html += \"
\"\n", + "\n", + "display(HTML(html))" + ] + }, + { + "cell_type": "markdown", + "id": "e27b68ac-bec8-4415-919e-8b916bc35816", + "metadata": {}, + "source": [ + "Above we have seen how to query our data using the `maxSimInvHamming(...)` function. \n", + "We can also just pass the full fidelity col pali vector and use the `maxSimDotProduct(...)` function for [asymmetric similarity](https://www.elastic.co/guide/en/elasticsearch/reference/8.18/rank-vectors.html#rank-vectors-scoring) between the vectors. " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "32fd9ee4-d7c6-4954-a766-7b06735290ff", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\"image_104.jpg\"\"image_3.jpg\"\"image_2.jpg\"\"image_12.jpg\"\"image_92.jpg\"
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "query = \"What do companies use for recruiting?\"\n", + "query_vector = create_col_pali_query_vectors(query)\n", + "es_query = {\n", + " \"_source\": False,\n", + " \"query\": {\n", + " \"script_score\": {\n", + " \"query\": {\n", + " \"match_all\": {}\n", + " },\n", + " \"script\": {\n", + " \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n", + " \"params\": {\n", + " \"query_vector\": query_vector\n", + " }\n", + " }\n", + " }\n", + " },\n", + " \"size\": 5\n", + "}\n", + "\n", + "results = es.search(index=INDEX_NAME, body=es_query)\n", + "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n", + "\n", + "html = \"
\"\n", + "for image_id in image_ids:\n", + " image_path = os.path.join(DOCUMENT_DIR, image_id)\n", + " html += f'\"{image_id}\"'\n", + "html += \"
\"\n", + "\n", + "display(HTML(html))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee8df1e3-af66-4e35-9c26-7257c281536f", + "metadata": {}, + "outputs": [], + "source": [ + "# We kill the kernel forcefully to free up the memory from the ColPali model. \n", + "print(\"Shutting down the kernel to free memory...\")\n", + "import os\n", + "os._exit(0)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dependecy-test-colpali-blog", + "language": "python", + "name": "dependecy-test-colpali-blog" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/colpali/03_average_vector.ipynb b/notebooks/colpali/03_average_vector.ipynb new file mode 100644 index 00000000..0a6b3058 --- /dev/null +++ b/notebooks/colpali/03_average_vector.ipynb @@ -0,0 +1,421 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5d4c33a1-a009-4d41-aa04-f100d012f6ce", + "metadata": {}, + "source": [ + "# Scalable late interaction vectors in Elasticsearch: Average Vectors #\n", + "\n", + "In this notebook, we will be looking at how scale search with late interaction models. We will be taking the average vector over our late interaction multi-vectors and use Elasticsearchs vector search capabilities to achieve scalable search over billions of vectors. \n", + " \n", + "This notebook builds on part 1 where we downloaded the images, created ColPali vectors and saved them to disk. Please execute this notebook before trying the techniques in this notebook. \n", + "\n", + "Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. " + ] + }, + { + "cell_type": "markdown", + "id": "81fec537-e52b-4e03-967f-b227a5349cae", + "metadata": {}, + "source": [ + "This is the key part of this notebook. We use `to_avg_vector(vectors)` to convert our 2d array into a single vector that holds the \"average meaning\" of the document. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def to_bit_vectors(embedding: list) -> list:\n", + " embeddings = []\n", + " for idx, patch_embedding in enumerate(embedding):\n", + " patch_embedding = np.array(patch_embedding)\n", + " binary_vector = (\n", + " np.packbits(np.where(patch_embedding > 0, 1, 0))\n", + " .astype(np.int8)\n", + " .tobytes()\n", + " .hex()\n", + " )\n", + " embeddings.append(binary_vector)\n", + " return embeddings\n", + "\n", + "def to_avg_vector(vectors):\n", + " vectors_array = np.array(vectors)\n", + " \n", + " avg_vector = np.mean(vectors_array, axis=0)\n", + " \n", + " norm = np.linalg.norm(avg_vector)\n", + " if norm > 0:\n", + " normalized_avg_vector = avg_vector / norm\n", + " else:\n", + " normalized_avg_vector = avg_vector\n", + "\n", + " return normalized_avg_vector.tolist()" + ] + }, + { + "cell_type": "markdown", + "id": "80d0bb03-36e9-4050-83c9-0cb54486842b", + "metadata": {}, + "source": [ + "Here we are defining our mapping for our Elasticsearch index. Note how we are using the `dense_vector` field type for our average vector. \n", + "This allows us to leverage the highly optimized HNSW indexing structures in Elasticsearch with which we can scale our search to billions of documents. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] Index 'searchlabs-colpali-average-vector' already exists.\n" + ] + } + ], + "source": [ + "import os\n", + "from dotenv import load_dotenv\n", + "from elasticsearch import Elasticsearch\n", + "\n", + "load_dotenv(\"elastic.env\")\n", + "\n", + "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n", + "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n", + "INDEX_NAME = \"searchlabs-colpali-average-vector\"\n", + "\n", + "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n", + "\n", + "mappings = {\n", + " \"mappings\": {\n", + " \"properties\": {\n", + " \"avg_vector\": {\n", + " \"type\": \"dense_vector\",\n", + " \"dims\": 128,\n", + " \"index\": True,\n", + " \"similarity\": \"dot_product\"\n", + " },\n", + " \"col_pali_vectors\": {\n", + " \"type\": \"rank_vectors\",\n", + " \"element_type\": \"bit\"\n", + " }\n", + " }\n", + " }\n", + "}\n", + "\n", + "if not es.indices.exists(index=INDEX_NAME):\n", + " print(f\"[INFO] Creating index: {INDEX_NAME}\")\n", + " es.indices.create(index=INDEX_NAME, body=mappings)\n", + "else:\n", + " print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n", + "\n", + "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n", + " for attempt in range(1, retries + 1):\n", + " try:\n", + " return es_client.index(index=index, id=doc_id, document=document)\n", + " except Exception as e:\n", + " if attempt < retries:\n", + " wait_time = initial_backoff * (2 ** (attempt - 1))\n", + " print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n", + " time.sleep(wait_time)\n", + " else:\n", + " print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n", + " raise" + ] + }, + { + "cell_type": "markdown", + "id": "662569b6-77d0-4a0d-8ad6-5d698cf0404c", + "metadata": {}, + "source": [ + "Here we are looping over all our vectors and convert them into our average vectors. \n", + "We still save our full fidelity ColPali vectors as bit vectors as we can use them for reranking. More on that later. " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4a9d08424a504956a4f50208a19cce90", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Indexing documents: 0%| | 0/500 [00:00 list:\n", + " queries = col_pali_processor.process_queries([query]).to(model.device)\n", + " with torch.no_grad():\n", + " return model(**queries).tolist()[0]" + ] + }, + { + "cell_type": "markdown", + "id": "ae89aa6f-7022-4ee7-b930-7e1e94c9abfc", + "metadata": {}, + "source": [ + "Again we create our query vector by using the ColPali model and calculating the average vector. We can now use Elasticsearches `knn` query to compare our single vector to the average vectors within our index. \n", + "Notice that the document that we have found in our previous examples with the title *Social Media for Recruitment* is now in 5th position. See the next cell on how we can handle this. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\"image_12.jpg\"\"image_3.jpg\"\"image_49.jpg\"\"image_123.jpg\"\"image_104.jpg\"
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import display, HTML\n", + "import os\n", + "import json\n", + "\n", + "DOCUMENT_DIR = \"searchlabs-colpali\"\n", + "\n", + "query = \"What do companies use for recruiting?\"\n", + "query_vector = to_avg_vector(create_col_pali_query_vectors(query))\n", + "es_query = {\n", + " \"_source\": False,\n", + " \"knn\": {\n", + " \"field\": \"avg_vector\",\n", + " \"query_vector\": query_vector,\n", + " \"k\": 10,\n", + " \"num_candidates\": 100\n", + " },\n", + " \"size\": 5\n", + "}\n", + "\n", + "results = es.search(index=INDEX_NAME, body=es_query)\n", + "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n", + "\n", + "html = \"
\"\n", + "for image_id in image_ids:\n", + " image_path = os.path.join(DOCUMENT_DIR, image_id)\n", + " html += f'\"{image_id}\"'\n", + "html += \"
\"\n", + "\n", + "display(HTML(html))" + ] + }, + { + "cell_type": "markdown", + "id": "3cf39408-c1b3-4832-a2a6-4c53fb4575d7", + "metadata": {}, + "source": [ + "In the cell above we have seen that the document *Social Media for Recruitment* is not considered the most relevant anymore. \n", + "\n", + "What we want to do to handle this is to run our KNN vector stage as a **first stage retrieval** algorithm. This is the `knn retriever` in the example below. \n", + "We then run a **first stage retrieval** rescoring algorithm on a smaller set of these documents. For this we can use higher fidelity ColPali vectors. This is the `rescore` section in the example below. " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b04f7fb8-a787-42bc-90a5-c20282cefc0c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\"image_104.jpg\"\"image_3.jpg\"\"image_2.jpg\"\"image_12.jpg\"\"image_49.jpg\"
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "query = \"What do companies use for recruiting?\"\n", + "col_pali_vector = create_col_pali_query_vectors(query)\n", + "avg_vector = to_avg_vector(col_pali_vector)\n", + "es_query = {\n", + " \"_source\": False,\n", + " \"retriever\": {\n", + " \"rescorer\": {\n", + " \"retriever\": {\n", + " \"knn\": {\n", + " \"field\": \"avg_vector\",\n", + " \"query_vector\": avg_vector,\n", + " \"k\": 10,\n", + " \"num_candidates\": 100\n", + " }\n", + " },\n", + " \"rescore\": {\n", + " \"window_size\": 10,\n", + " \"query\": {\n", + " \"rescore_query\": {\n", + " \"script_score\": {\n", + " \"query\": {\n", + " \"match_all\": {}\n", + " },\n", + " \"script\": {\n", + " \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n", + " \"params\": {\n", + " \"query_vector\": col_pali_vector\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " }\n", + " },\n", + " \"size\": 5\n", + "}\n", + "\n", + "results = es.search(index=INDEX_NAME, body=es_query)\n", + "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n", + "\n", + "html = \"
\"\n", + "for image_id in image_ids:\n", + " image_path = os.path.join(DOCUMENT_DIR, image_id)\n", + " html += f'\"{image_id}\"'\n", + "html += \"
\"\n", + "\n", + "display(HTML(html))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5347aef4-8ac0-48ec-a2cb-10745ec1f487", + "metadata": {}, + "outputs": [], + "source": [ + "# We kill the kernel forcefully to free up the memory from the ColPali model. \n", + "print(\"Shutting down the kernel to free memory...\")\n", + "import os\n", + "os._exit(0)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dependecy-test-colpali-blog", + "language": "python", + "name": "dependecy-test-colpali-blog" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/colpali/04_token_pooling.ipynb b/notebooks/colpali/04_token_pooling.ipynb new file mode 100644 index 00000000..a5101b86 --- /dev/null +++ b/notebooks/colpali/04_token_pooling.ipynb @@ -0,0 +1,320 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4998ae3c-74a9-4fee-812f-2fc2513c3915", + "metadata": {}, + "source": [ + "# Scalable late interaction vectors in Elasticsearch: Token Pooling #\n", + "\n", + "In this notebook, we will be looking at how scale search with late interaction models. We will be looking a token pooling - a technique to reduce the dimensionality of the late interaction multi-vectors by clustering similar information. This technique can of course be combined with the other techniques we have discussed in the previous notebooks. \n", + "\n", + "This notebook builds on part 1 where we downloaded the images, created ColPali vectors and saved them to disk. Please execute this notebook before trying the techniques in this notebook. \n", + "\n", + "Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "def to_bit_vectors(embedding: list) -> list:\n", + " embeddings = []\n", + " for idx, patch_embedding in enumerate(embedding):\n", + " patch_embedding = np.array(patch_embedding)\n", + " binary_vector = (\n", + " np.packbits(np.where(patch_embedding > 0, 1, 0))\n", + " .astype(np.int8)\n", + " .tobytes()\n", + " .hex()\n", + " )\n", + " embeddings.append(binary_vector)\n", + " return embeddings" + ] + }, + { + "cell_type": "markdown", + "id": "7e0887ca-f194-429b-9afa-208d047a75e4", + "metadata": {}, + "source": [ + "We will be using the `HierarchicalTokenPooler` from the [colpali-engine](https://github.com/illuin-tech/colpali?tab=readme-ov-file#token-pooling) to reduce the dimensions of our vector. \n", + "The authors recommend a `pool_factor=3` for most cases, but you should always tests how it impact the relevancy of your dataset. " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9871c9c5-c923-4deb-9f5b-aa6796ba0bbf", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from colpali_engine.compression.token_pooling import HierarchicalTokenPooler\n", + "\n", + "pooler = HierarchicalTokenPooler(pool_factor=3) # test on your data for a good pool_factor\n", + "\n", + "def pool_vectors(embedding: list) -> list:\n", + " tensor = torch.tensor(embedding).unsqueeze(0)\n", + " pooled = pooler.pool_embeddings(tensor)\n", + " return pooled.squeeze(0).tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] Index 'searchlabs-colpali-token-pooling' already exists.\n" + ] + } + ], + "source": [ + "import os\n", + "from dotenv import load_dotenv\n", + "from elasticsearch import Elasticsearch\n", + "\n", + "load_dotenv(\"elastic.env\")\n", + "\n", + "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n", + "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n", + "INDEX_NAME = \"searchlabs-colpali-token-pooling\"\n", + "\n", + "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n", + "\n", + "mappings = {\n", + " \"mappings\": {\n", + " \"properties\": {\n", + " \"pooled_col_pali_vectors\": {\n", + " \"type\": \"rank_vectors\",\n", + " \"element_type\": \"bit\"\n", + " }\n", + " }\n", + " }\n", + "}\n", + "\n", + "if not es.indices.exists(index=INDEX_NAME):\n", + " print(f\"[INFO] Creating index: {INDEX_NAME}\")\n", + " es.indices.create(index=INDEX_NAME, body=mappings)\n", + "else:\n", + " print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n", + "\n", + "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n", + " for attempt in range(1, retries + 1):\n", + " try:\n", + " return es_client.index(index=index, id=doc_id, document=document)\n", + " except Exception as e:\n", + " if attempt < retries:\n", + " wait_time = initial_backoff * (2 ** (attempt - 1))\n", + " print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n", + " time.sleep(wait_time)\n", + " else:\n", + " print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n", + " raise" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cef0c48b9b5d4b3982fbdb4773494ec8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Indexing documents: 0%| | 0/500 [00:00 list:\n", + " queries = col_pali_processor.process_queries([query]).to(model.device)\n", + " with torch.no_grad():\n", + " return model(**queries).tolist()[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\"image_3.jpg\"\"image_104.jpg\"\"image_2.jpg\"\"image_12.jpg\"\"image_120.jpg\"
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import display, HTML\n", + "import os\n", + "import json\n", + "\n", + "DOCUMENT_DIR = \"searchlabs-colpali\"\n", + "\n", + "query = \"What do companies use for recruiting?\"\n", + "query_vector = create_col_pali_query_vectors(query)\n", + "es_query = {\n", + " \"_source\": False,\n", + " \"query\": {\n", + " \"script_score\": {\n", + " \"query\": {\n", + " \"match_all\": {}\n", + " },\n", + " \"script\": {\n", + " \"source\": \"maxSimDotProduct(params.query_vector, 'pooled_col_pali_vectors')\",\n", + " \"params\": {\n", + " \"query_vector\": query_vector\n", + " }\n", + " }\n", + " }\n", + " },\n", + " \"size\": 5\n", + "}\n", + "\n", + "results = es.search(index=INDEX_NAME, body=es_query)\n", + "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n", + "\n", + "html = \"
\"\n", + "for image_id in image_ids:\n", + " image_path = os.path.join(DOCUMENT_DIR, image_id)\n", + " html += f'\"{image_id}\"'\n", + "html += \"
\"\n", + "\n", + "display(HTML(html))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32fd9ee4-d7c6-4954-a766-7b06735290ff", + "metadata": {}, + "outputs": [], + "source": [ + "# We kill the kernel forcefully to free up the memory from the ColPali model. \n", + "print(\"Shutting down the kernel to free memory...\")\n", + "import os\n", + "os._exit(0)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dependecy-test-colpali-blog", + "language": "python", + "name": "dependecy-test-colpali-blog" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}