@@ -1,9 +1,10 @@
+import asyncio
 from typing import Any, Dict, List, Optional

 from haystack import Document, component
 from tqdm import tqdm

-from ollama import Client
+from ollama import AsyncClient, Client


 @component
@@ -74,6 +75,20 @@ def __init__(
         self.prefix = prefix

         self._client = Client(host=self.url, timeout=self.timeout)
+        self._async_client = AsyncClient(host=self.url, timeout=self.timeout)
+
+    def _prepare_input(self, documents: List[Document]) -> List[Document]:
+        """
+        Prepares the list of documents to embed by appropriate validation.
+        """
+        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
+            msg = (
+                "OllamaDocumentEmbedder expects a list of Documents as input."
+                "In case you want to embed a list of strings, please use the OllamaTextEmbedder."
+            )
+            raise TypeError(msg)
+
+        return documents

     def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
         """
@@ -115,6 +130,35 @@ def _embed_batch(

         return all_embeddings

+    async def _embed_batch_async(
+        self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        """
+        Internal method to embed a batch of texts asynchronously.
+        """
+        all_embeddings = []
+
+        batches = [texts_to_embed[i : i + batch_size] for i in range(0, len(texts_to_embed), batch_size)]
+
+        tasks = [
+            self._async_client.embed(
+                model=self.model,
+                input=batch,
+                options=generation_kwargs,
+            )
+            for batch in batches
+        ]
+
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        for idx, res in enumerate(results):
+            if isinstance(res, BaseException):
+                err_msg = f"Embedding batch {idx} raised an exception."
+                raise RuntimeError(err_msg)
+            all_embeddings.extend(res["embeddings"])
+
+        return all_embeddings
+
     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
         """
@@ -130,12 +174,11 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
         - `documents`: Documents with embedding information attached
         - `meta`: The metadata collected during the embedding process
         """
-        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
-            msg = (
-                "OllamaDocumentEmbedder expects a list of Documents as input."
-                "In case you want to embed a list of strings, please use the OllamaTextEmbedder."
-            )
-            raise TypeError(msg)
+        documents = self._prepare_input(documents=documents)
+
+        if not documents:
+            # return early if we were passed an empty list
+            return {"documents": [], "meta": {}}

         generation_kwargs = generation_kwargs or self.generation_kwargs

@@ -148,3 +191,37 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A
             doc.embedding = emb

         return {"documents": documents, "meta": {"model": self.model}}
+
+    @component.output_types(documents=List[Document], meta=Dict[str, Any])
+    async def run_async(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
+        """
+        Asynchronously run an Ollama Model to compute embeddings of the provided documents.
+
+        :param documents:
+            Documents to be converted to an embedding.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+            top_p, etc. See the
+            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+        :returns: A dictionary with the following keys:
+        - `documents`: Documents with embedding information attached
+        - `meta`: The metadata collected during the embedding process
+        """
+
+        documents = self._prepare_input(documents=documents)
+
+        if not documents:
+            # return early if we were passed an empty list
+            return {"documents": [], "meta": {}}
+
+        generation_kwargs = generation_kwargs or self.generation_kwargs
+
+        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
+        embeddings = await self._embed_batch_async(
+            texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs
+        )
+
+        for doc, emb in zip(documents, embeddings):
+            doc.embedding = emb
+
+        return {"documents": documents, "meta": {"model": self.model}}
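
For reviewers who want to try the new `run_async` path locally, a minimal usage sketch follows. It is not part of the diff: the `haystack_integrations.components.embedders.ollama` import path, the default Ollama server URL, and the `nomic-embed-text` model name are assumptions for illustration, and a local Ollama server with an embedding model already pulled is required. As the diff shows, `run_async` delegates to `_embed_batch_async`, which sends all batches concurrently via `asyncio.gather` rather than one at a time.

```python
# Minimal sketch (not part of the commit). Assumes the ollama-haystack package
# is installed, an Ollama server is running at its default URL, and an
# embedding model has been pulled, e.g. `ollama pull nomic-embed-text`.
import asyncio

from haystack import Document
from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder


async def main():
    # Model name is illustrative; any embedding model available in Ollama works.
    embedder = OllamaDocumentEmbedder(model="nomic-embed-text")
    docs = [
        Document(content="Haystack is an LLM orchestration framework."),
        Document(content="Ollama runs models locally."),
    ]

    # run_async validates the input, then awaits _embed_batch_async, which
    # fans the batches out concurrently through asyncio.gather.
    result = await embedder.run_async(documents=docs)
    print(len(result["documents"][0].embedding))  # embedding dimension of the first document
    print(result["meta"])


asyncio.run(main())
```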