
Commit 43dab20

srini047 and Amnah199 authored
feat: add run_async support for OllamaDocumentEmbedder (#1878)
* feat: add run_async support for OllamaDocumentEmbedder
* fix: use one client per doc embedder class obj
* docs: address review comments
* fix: better async handling
* fix: type checking

---------

Co-authored-by: Amna Mubashar <amnahkhan.ak@gmail.com>
1 parent 653aa24 commit 43dab20

File tree

2 files changed: +101, -7 lines changed


integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py

Lines changed: 84 additions & 7 deletions
@@ -1,9 +1,10 @@
+import asyncio
 from typing import Any, Dict, List, Optional
 
 from haystack import Document, component
 from tqdm import tqdm
 
-from ollama import Client
+from ollama import AsyncClient, Client
 
 
 @component
@@ -74,6 +75,20 @@ def __init__(
         self.prefix = prefix
 
         self._client = Client(host=self.url, timeout=self.timeout)
+        self._async_client = AsyncClient(host=self.url, timeout=self.timeout)
+
+    def _prepare_input(self, documents: List[Document]) -> List[Document]:
+        """
+        Prepares the list of documents to embed by appropriate validation.
+        """
+        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
+            msg = (
+                "OllamaDocumentEmbedder expects a list of Documents as input."
+                "In case you want to embed a list of strings, please use the OllamaTextEmbedder."
+            )
+            raise TypeError(msg)
+
+        return documents
 
     def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
         """
@@ -115,6 +130,35 @@ def _embed_batch(
 
         return all_embeddings
 
+    async def _embed_batch_async(
+        self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        """
+        Internal method to embed a batch of texts asynchronously.
+        """
+        all_embeddings = []
+
+        batches = [texts_to_embed[i : i + batch_size] for i in range(0, len(texts_to_embed), batch_size)]
+
+        tasks = [
+            self._async_client.embed(
+                model=self.model,
+                input=batch,
+                options=generation_kwargs,
+            )
+            for batch in batches
+        ]
+
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        for idx, res in enumerate(results):
+            if isinstance(res, BaseException):
+                err_msg = f"Embedding batch {idx} raised an exception."
+                raise RuntimeError(err_msg)
+            all_embeddings.extend(res["embeddings"])
+
+        return all_embeddings
+
     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
         """
@@ -130,12 +174,11 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
             - `documents`: Documents with embedding information attached
             - `meta`: The metadata collected during the embedding process
         """
-        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
-            msg = (
-                "OllamaDocumentEmbedder expects a list of Documents as input."
-                "In case you want to embed a list of strings, please use the OllamaTextEmbedder."
-            )
-            raise TypeError(msg)
+        documents = self._prepare_input(documents=documents)
+
+        if not documents:
+            # return early if we were passed an empty list
+            return {"documents": [], "meta": {}}
 
         generation_kwargs = generation_kwargs or self.generation_kwargs
 
@@ -148,3 +191,37 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
             doc.embedding = emb
 
         return {"documents": documents, "meta": {"model": self.model}}
+
+    @component.output_types(documents=List[Document], meta=Dict[str, Any])
+    async def run_async(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
+        """
+        Asynchronously run an Ollama Model to compute embeddings of the provided documents.
+
+        :param documents:
+            Documents to be converted to an embedding.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+            top_p, etc. See the
+            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+        :returns: A dictionary with the following keys:
+            - `documents`: Documents with embedding information attached
+            - `meta`: The metadata collected during the embedding process
+        """
+
+        documents = self._prepare_input(documents=documents)
+
+        if not documents:
+            # return early if we were passed an empty list
+            return {"documents": [], "meta": {}}
+
+        generation_kwargs = generation_kwargs or self.generation_kwargs
+
+        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
+        embeddings = await self._embed_batch_async(
+            texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs
+        )
+
+        for doc, emb in zip(documents, embeddings):
+            doc.embedding = emb
+
+        return {"documents": documents, "meta": {"model": self.model}}
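The new `_embed_batch_async` builds batches by slicing, turns each batch into one `AsyncClient.embed` call, and fans them out with `asyncio.gather(..., return_exceptions=True)`, re-raising the first failed batch as a `RuntimeError`. Below is a minimal, self-contained sketch of that pattern; `embed_stub` is a hypothetical stand-in for `AsyncClient.embed`, so the snippet runs without an Ollama server.

import asyncio
from typing import Any, Dict, List


async def embed_stub(batch: List[str]) -> Dict[str, Any]:
    # Hypothetical stand-in for AsyncClient.embed: returns a fake 2-dimensional
    # embedding per input text, yielding control as a real network call would.
    await asyncio.sleep(0)
    return {"embeddings": [[float(len(text)), 0.0] for text in batch]}


async def embed_all(texts: List[str], batch_size: int) -> List[List[float]]:
    # Split the inputs into batches, one embed call per batch, run them concurrently.
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    results = await asyncio.gather(*(embed_stub(batch) for batch in batches), return_exceptions=True)

    all_embeddings: List[List[float]] = []
    for idx, res in enumerate(results):
        if isinstance(res, BaseException):
            # Mirror the component: fail fast if any batch raised.
            raise RuntimeError(f"Embedding batch {idx} raised an exception.") from res
        all_embeddings.extend(res["embeddings"])
    return all_embeddings


print(asyncio.run(embed_all(["one", "two", "three"], batch_size=2)))

Because `return_exceptions=True` keeps one failing batch from cancelling the others mid-flight, the loop afterwards decides how the failure is surfaced.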

integrations/ollama/tests/test_document_embedder.py

Lines changed: 17 additions & 0 deletions
@@ -50,6 +50,23 @@ def test_run(self):
             Document(content="Llamas have been used as pack animals for centuries, especially in South America."),
         ]
         result = embedder.run(list_of_docs)
+
+        assert result["meta"]["model"] == "nomic-embed-text"
+        documents = result["documents"]
+        assert len(documents) == 3
+        assert all(isinstance(element, float) for document in documents for element in document.embedding)
+
+    @pytest.mark.asyncio
+    @pytest.mark.integration
+    async def test_run_async(self):
+        embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2)
+        list_of_docs = [
+            Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."),
+            Document(content="The Andes mountains are the natural habitat of many llamas."),
+            Document(content="Llamas have been used as pack animals for centuries, especially in South America."),
+        ]
+        result = await embedder.run_async(list_of_docs)
+
         assert result["meta"]["model"] == "nomic-embed-text"
         documents = result["documents"]
         assert len(documents) == 3
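For completeness, a hedged end-to-end sketch of calling the new `run_async` outside of pytest. It assumes a local Ollama server reachable at the component's default URL with the `nomic-embed-text` model already pulled, and that `OllamaDocumentEmbedder` is exported from `haystack_integrations.components.embedders.ollama` as the package layout in this diff suggests.

import asyncio

from haystack import Document
from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder


async def main() -> None:
    # Same constructor arguments as in the integration test above.
    embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2)
    docs = [
        Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."),
        Document(content="The Andes mountains are the natural habitat of many llamas."),
    ]
    result = await embedder.run_async(docs)
    print(result["meta"])                          # {'model': 'nomic-embed-text'}
    print(len(result["documents"][0].embedding))   # vector length depends on the model


asyncio.run(main())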
