feat: add run_async support for OllamaDocumentEmbedder #1878

Merged: 12 commits, Jun 10, 2025

@@ -1,9 +1,10 @@
import asyncio
from typing import Any, Dict, List, Optional

from haystack import Document, component
from tqdm import tqdm

from ollama import AsyncClient, Client


@component
@@ -74,6 +75,20 @@ def __init__(
self.prefix = prefix

self._client = Client(host=self.url, timeout=self.timeout)
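        # AsyncClient mirrors the sync Client but exposes awaitable endpoints;
        # it backs run_async / _embed_batch_async below.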
self._async_client = AsyncClient(host=self.url, timeout=self.timeout)

def _prepare_input(self, documents: List[Document]) -> List[Document]:
"""
        Prepares the list of documents to embed, performing the appropriate validation.
"""
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
msg = (
"OllamaDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a list of strings, please use the OllamaTextEmbedder."
)
raise TypeError(msg)

return documents

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
@@ -115,6 +130,35 @@ def _embed_batch(

return all_embeddings

async def _embed_batch_async(
self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
    ) -> List[List[float]]:
"""
Internal method to embed a batch of texts asynchronously.
"""
all_embeddings = []

batches = [texts_to_embed[i : i + batch_size] for i in range(0, len(texts_to_embed), batch_size)]

tasks = [
self._async_client.embed(
model=self.model,
input=batch,
options=generation_kwargs,
)
for batch in batches
]

results = await asyncio.gather(*tasks, return_exceptions=True)
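        # With return_exceptions=True, gather does not fail fast: each result is
        # either an embed response or the exception raised by that batch. gather
        # also preserves input order, so embeddings stay aligned with the batches.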

for idx, res in enumerate(results):
if isinstance(res, BaseException):
err_msg = f"Embedding batch {idx} raised an exception."
raise RuntimeError(err_msg)
all_embeddings.extend(res["embeddings"])

return all_embeddings

@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
@@ -130,12 +174,11 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
- `documents`: Documents with embedding information attached
- `meta`: The metadata collected during the embedding process
"""
        documents = self._prepare_input(documents=documents)

if not documents:
# return early if we were passed an empty list
return {"documents": [], "meta": {}}

generation_kwargs = generation_kwargs or self.generation_kwargs

@@ -148,3 +191,37 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
doc.embedding = emb

return {"documents": documents, "meta": {"model": self.model}}

@component.output_types(documents=List[Document], meta=Dict[str, Any])
async def run_async(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
"""
Asynchronously run an Ollama Model to compute embeddings of the provided documents.

:param documents:
Documents to be converted to an embedding.
:param generation_kwargs:
Optional arguments to pass to the Ollama generation endpoint, such as temperature,
top_p, etc. See the
[Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
:returns: A dictionary with the following keys:
- `documents`: Documents with embedding information attached
- `meta`: The metadata collected during the embedding process
"""

documents = self._prepare_input(documents=documents)

if not documents:
# return early if we were passed an empty list
return {"documents": [], "meta": {}}

generation_kwargs = generation_kwargs or self.generation_kwargs

texts_to_embed = self._prepare_texts_to_embed(documents=documents)
embeddings = await self._embed_batch_async(
texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs
)

for doc, emb in zip(documents, embeddings):
doc.embedding = emb

return {"documents": documents, "meta": {"model": self.model}}
integrations/ollama/tests/test_document_embedder.py (17 additions, 0 deletions)
@@ -50,6 +50,23 @@ def test_run(self):
Document(content="Llamas have been used as pack animals for centuries, especially in South America."),
]
result = embedder.run(list_of_docs)

assert result["meta"]["model"] == "nomic-embed-text"
documents = result["documents"]
assert len(documents) == 3
assert all(isinstance(element, float) for document in documents for element in document.embedding)

@pytest.mark.asyncio
@pytest.mark.integration
async def test_run_async(self):
embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2)
list_of_docs = [
Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."),
Document(content="The Andes mountains are the natural habitat of many llamas."),
Document(content="Llamas have been used as pack animals for centuries, especially in South America."),
]
result = await embedder.run_async(list_of_docs)

assert result["meta"]["model"] == "nomic-embed-text"
documents = result["documents"]
assert len(documents) == 3
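
Because _embed_batch_async collapses any failed batch into a single RuntimeError, callers of run_async have exactly one exception type to handle. A hypothetical sketch (not part of the PR), reusing the embedder setup from the tests above:

import asyncio

from haystack import Document

from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder


async def embed_or_empty(docs):
    embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2)
    try:
        return await embedder.run_async(docs)
    except RuntimeError as exc:
        # One failed batch fails the whole call; the underlying error is
        # chained on exc.__cause__.
        print(f"Embedding failed: {exc}")
        return {"documents": [], "meta": {}}


result = asyncio.run(embed_or_empty([Document(content="Llamas hum when content.")]))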