Skip to content

Commit aafc1c8

Browse files
authored
feat: add run_async support for CohereTextEmbedder (#1873)
* feat: add run_async support for CohereTextEmbedder * fix: review comments
1 parent f8fdc0d commit aafc1c8

File tree

3 files changed

+77
-13
lines changed

3 files changed

+77
-13
lines changed

integrations/cohere/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,5 @@ markers = [
168168
"integration: integration tests",
169169
]
170170
log_cli = true
171+
asyncio_mode = "auto"
172+
asyncio_default_fixture_loop_scope = "class"

integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@ def __init__(
7575
self.timeout = timeout
7676
self.embedding_type = embedding_type or EmbeddingTypes.FLOAT
7777

78+
def _prepare_input(self, text: str) -> str:
79+
if not isinstance(text, str):
80+
msg = (
81+
"CohereTextEmbedder expects a string as input."
82+
"In case you want to embed a list of Documents, please use the CohereDocumentEmbedder."
83+
)
84+
raise TypeError(msg)
85+
86+
return text
87+
7888
def to_dict(self) -> Dict[str, Any]:
7989
"""
8090
Serializes the component to a dictionary.
@@ -114,20 +124,19 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder":
114124

115125
@component.output_types(embedding=List[float], meta=Dict[str, Any])
116126
def run(self, text: str):
117-
"""Embed text.
127+
"""
128+
Embed text.
118129
119-
:param text: the text to embed.
120-
:returns: A dictionary with the following keys:
121-
- `embedding`: the embedding of the text.
122-
- `meta`: metadata about the request.
123-
:raises TypeError: If the input is not a string.
130+
:param text:
131+
the text to embed.
132+
:returns:
133+
A dictionary with the following keys:
134+
- `embedding`: the embedding of the text.
135+
- `meta`: metadata about the request.
136+
:raises TypeError:
137+
If the input is not a string.
124138
"""
125-
if not isinstance(text, str):
126-
msg = (
127-
"CohereTextEmbedder expects a string as input."
128-
"In case you want to embed a list of Documents, please use the CohereDocumentEmbedder."
129-
)
130-
raise TypeError(msg)
139+
text = self._prepare_input(text=text)
131140

132141
# Establish connection to API
133142

@@ -158,3 +167,40 @@ def run(self, text: str):
158167
)
159168

160169
return {"embedding": embedding[0], "meta": metadata}
170+
171+
@component.output_types(embedding=List[float], meta=Dict[str, Any])
172+
async def run_async(self, text: str):
173+
"""
174+
Asynchronously embed text.
175+
176+
This is the asynchronous version of the `run` method. It has the same parameters and return values
177+
but can be used with `await` in async code.
178+
179+
:param text:
180+
Text to embed.
181+
182+
:returns:
183+
A dictionary with the following keys:
184+
- `embedding`: the embedding of the text.
185+
- `meta`: metadata about the request.
186+
187+
:raises TypeError:
188+
If the input is not a string.
189+
"""
190+
text = self._prepare_input(text=text)
191+
192+
api_key = self.api_key.resolve_value()
193+
assert api_key is not None
194+
195+
cohere_client = AsyncClientV2(
196+
api_key,
197+
base_url=self.api_base_url,
198+
timeout=self.timeout,
199+
client_name="haystack",
200+
)
201+
202+
embedding, metadata = await get_async_response(
203+
cohere_client, [text], self.model, self.input_type, self.truncate, self.embedding_type
204+
)
205+
206+
return {"embedding": embedding[0], "meta": metadata}

integrations/cohere/tests/test_text_embedder.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,5 +116,21 @@ def test_run_wrong_input_format(self):
116116
def test_run(self):
117117
embedder = CohereTextEmbedder()
118118
text = "The food was delicious"
119-
result = embedder.run(text)
119+
result = embedder.run(text=text)
120+
121+
assert len(result["embedding"]) == 4096
122+
assert all(isinstance(x, float) for x in result["embedding"])
123+
124+
@pytest.mark.asyncio
125+
@pytest.mark.skipif(
126+
not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None),
127+
reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.",
128+
)
129+
@pytest.mark.integration
130+
async def test_run_async(self):
131+
embedder = CohereTextEmbedder()
132+
text = "The food was delicious"
133+
result = await embedder.run_async(text=text)
134+
135+
assert len(result["embedding"]) == 4096
120136
assert all(isinstance(x, float) for x in result["embedding"])

0 commit comments

Comments
 (0)