Moving to supporting blog content / formatting
peter-strsr committed Mar 5, 2025
1 parent 65ece4e commit c57a507
Showing 3 changed files with 27 additions and 32 deletions.
File renamed without changes.
@@ -25,7 +25,8 @@
"source": [
"!pip install -r requirements.txt\n",
"from IPython.display import clear_output\n",
"clear_output() # for less space usage. "
"\n",
"clear_output() # for less space usage."
]
},
{
@@ -64,15 +65,15 @@
],
"source": [
"from datasets import load_dataset\n",
"from tqdm.notebook import tqdm \n",
"from tqdm.notebook import tqdm\n",
"import os\n",
"\n",
"DATASET_NAME = \"vidore/infovqa_test_subsampled\"\n",
"DOCUMENT_DIR = \"searchlabs-colpali\"\n",
"\n",
"os.makedirs(DOCUMENT_DIR, exist_ok=True)\n",
"dataset = load_dataset(DATASET_NAME, split=\"test\")\n",
" \n",
"\n",
"for i, row in enumerate(tqdm(dataset, desc=\"Saving images to disk\")):\n",
" image = row.get(\"image\")\n",
" image_name = f\"image_{i}.jpg\"\n",
@@ -123,17 +124,21 @@
"model = ColPali.from_pretrained(\n",
" \"vidore/colpali-v1.3\",\n",
" torch_dtype=torch.float32,\n",
" device_map=\"mps\", # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n",
" device_map=\"mps\", # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n",
").eval()\n",
"\n",
"col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n",
"\n",
"\n",
"def create_col_pali_image_vectors(image_path: str) -> list:\n",
" batch_images = col_pali_processor.process_images([Image.open(image_path)]).to(model.device)\n",
" \n",
" batch_images = col_pali_processor.process_images([Image.open(image_path)]).to(\n",
" model.device\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" return model(**batch_images).tolist()[0]\n",
"\n",
"\n",
"def create_col_pali_query_vectors(query: str) -> list:\n",
" queries = col_pali_processor.process_queries([query]).to(model.device)\n",
" with torch.no_grad():\n",
@@ -194,9 +199,9 @@
" vectors_f32 = create_col_pali_image_vectors(image_path)\n",
" file_to_multi_vectors[file_name] = vectors_f32\n",
"\n",
"with open('col_pali_vectors.pkl', 'wb') as f:\n",
"with open(\"col_pali_vectors.pkl\", \"wb\") as f:\n",
" pickle.dump(file_to_multi_vectors, f)\n",
" \n",
"\n",
"print(f\"Saved {len(file_to_multi_vectors)} vector entries to disk\")"
]
},
@@ -239,22 +244,15 @@
"\n",
"es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n",
"\n",
"mappings = {\n",
" \"mappings\": {\n",
" \"properties\": {\n",
" \"col_pali_vectors\": {\n",
" \"type\": \"rank_vectors\"\n",
" }\n",
" }\n",
" }\n",
"}\n",
"mappings = {\"mappings\": {\"properties\": {\"col_pali_vectors\": {\"type\": \"rank_vectors\"}}}}\n",
"\n",
"if not es.indices.exists(index=INDEX_NAME):\n",
" print(f\"[INFO] Creating index: {INDEX_NAME}\")\n",
" es.indices.create(index=INDEX_NAME, body=mappings)\n",
"else:\n",
" print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n",
"\n",
"\n",
"def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n",
" for attempt in range(1, retries + 1):\n",
" try:\n",
@@ -304,18 +302,18 @@
}
],
"source": [
"with open('col_pali_vectors.pkl', 'rb') as f:\n",
"with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n",
" file_to_multi_vectors = pickle.load(f)\n",
"\n",
"for file_name, vectors in tqdm(file_to_multi_vectors.items(), desc=\"Index documents\"):\n",
" if es.exists(index=INDEX_NAME, id=file_name):\n",
" continue\n",
" \n",
"\n",
" index_document(\n",
" es_client=es, \n",
" index=INDEX_NAME, \n",
" doc_id=file_name, \n",
" document={\"col_pali_vectors\": vectors}\n",
" es_client=es,\n",
" index=INDEX_NAME,\n",
" doc_id=file_name,\n",
" document={\"col_pali_vectors\": vectors},\n",
" )"
]
},
@@ -360,18 +358,14 @@
" \"_source\": False,\n",
" \"query\": {\n",
" \"script_score\": {\n",
" \"query\": {\n",
" \"match_all\": {}\n",
" },\n",
" \"query\": {\"match_all\": {}},\n",
" \"script\": {\n",
" \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n",
" \"params\": {\n",
" \"query_vector\": create_col_pali_query_vectors(query)\n",
" }\n",
" }\n",
" \"params\": {\"query_vector\": create_col_pali_query_vectors(query)},\n",
" },\n",
" }\n",
" },\n",
" \"size\": 5\n",
" \"size\": 5,\n",
"}\n",
"\n",
"results = es.search(index=INDEX_NAME, body=es_query)\n",
@@ -393,9 +387,10 @@
"metadata": {},
"outputs": [],
"source": [
"# We kill the kernel forcefully to free up the memory from the ColPali model. \n",
"# We kill the kernel forcefully to free up the memory from the ColPali model.\n",
"print(\"Shutting down the kernel to free memory...\")\n",
"import os\n",
"\n",
"os._exit(0)"
]
}
File renamed without changes.
