Skip to content

Commit 3d69ccd

Browse files
feat: adding an HybridRetriever as a Supercomponent having OpenSearch as the document store (#1701)
* adding tests * linting and typing * adding env variable * env variable * extending docstring * removing generation part * updating tests * adding a run test with mocked sentence_transformers * fixing format
1 parent 4b0b586 commit 3d69ccd

File tree

9 files changed

+527
-6
lines changed

9 files changed

+527
-6
lines changed

.github/workflows/opensearch.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,16 @@ jobs:
5858
- name: Run tests
5959
run: hatch run cov-retry
6060

61+
- name: Run unit tests with lowest direct dependencies
62+
run: |
63+
hatch run uv pip compile pyproject.toml --resolution lowest-direct --output-file requirements_lowest_direct.txt
64+
hatch run uv pip install -r requirements_lowest_direct.txt
65+
hatch run test -m "not integration"
66+
6167
- name: Nightly - run unit tests with Haystack main branch
6268
if: github.event_name == 'schedule'
6369
run: |
70+
hatch env prune
6471
hatch run uv pip install git+https://github.com/deepset-ai/haystack.git@main
6572
hatch run cov-retry -m "not integration"
6673

integrations/opensearch/pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai>=2.11.0", "opensearch-py[async]>=2,<3"]
26+
dependencies = [
27+
"haystack-ai>=2.14.0",
28+
"opensearch-py[async]>=2.4.0,<3"]
2729

2830
[project.urls]
2931
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/opensearch#readme"
@@ -52,6 +54,7 @@ dependencies = [
5254
"haystack-pydoc-tools",
5355
"boto3",
5456
]
57+
5558
[tool.hatch.envs.default.scripts]
5659
test = "pytest {args:tests}"
5760
test-cov = "coverage run -m pytest {args:tests}"
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
45
from .bm25_retriever import OpenSearchBM25Retriever
56
from .embedding_retriever import OpenSearchEmbeddingRetriever
7+
from .open_search_hybrid_retriever import OpenSearchHybridRetriever
68

7-
__all__ = ["OpenSearchBM25Retriever", "OpenSearchEmbeddingRetriever"]
9+
__all__ = ["OpenSearchBM25Retriever", "OpenSearchEmbeddingRetriever", "OpenSearchHybridRetriever"]

integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
45
from typing import Any, Dict, List, Optional, Union
56

67
from haystack import component, default_from_dict, default_to_dict, logging
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
# SPDX-FileCopyrightText: 2023-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
6+
7+
from haystack import DeserializationError, Pipeline, default_from_dict, default_to_dict, logging, super_component
8+
from haystack.components.embedders.types import TextEmbedder
9+
from haystack.components.joiners import DocumentJoiner
10+
from haystack.components.joiners.document_joiner import JoinMode
11+
from haystack.core.serialization import component_from_dict, import_class_by_name
12+
from haystack.document_stores.types import FilterPolicy
13+
14+
from haystack_integrations.components.retrievers.opensearch import OpenSearchBM25Retriever, OpenSearchEmbeddingRetriever
15+
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
@super_component
21+
class OpenSearchHybridRetriever:
22+
"""
23+
A hybrid retriever that combines embedding-based and keyword-based retrieval from OpenSearch.
24+
25+
Example usage:
26+
27+
Make sure you have "sentence-transformers>=3.0.0":
28+
29+
pip install haystack-ai datasets "sentence-transformers>=3.0.0"
30+
31+
32+
And OpenSearch running. You can run OpenSearch with Docker:
33+
34+
docker run -d --name opensearch-nosec -p 9200:9200 -p 9600:9600 -e "discovery.type=single-node"
35+
-e "DISABLE_SECURITY_PLUGIN=true" opensearchproject/opensearch:2.12.0
36+
37+
```python
38+
from haystack import Document
39+
from haystack.components.embedders import SentenceTransformersTextEmbedder, SentenceTransformersDocumentEmbedder
40+
from haystack_integrations.components.retrievers.opensearch import OpenSearchHybridRetriever
41+
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore
42+
43+
# Initialize the document store
44+
doc_store = OpenSearchDocumentStore(
45+
hosts=["<http://localhost:9200>"],
46+
index="document_store",
47+
embedding_dim=384,
48+
)
49+
50+
# Create some sample documents
51+
docs = [
52+
Document(content="Machine learning is a subset of artificial intelligence."),
53+
Document(content="Deep learning is a subset of machine learning."),
54+
Document(content="Natural language processing is a field of AI."),
55+
Document(content="Reinforcement learning is a type of machine learning."),
56+
Document(content="Supervised learning is a type of machine learning."),
57+
]
58+
59+
# Embed the documents and add them to the document store
60+
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
61+
doc_embedder.warm_up()
62+
docs = doc_embedder.run(docs)
63+
doc_store.write_documents(docs['documents'])
64+
65+
# Initialize some haystack text embedder, in this case the SentenceTransformersTextEmbedder
66+
embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
67+
68+
# Initialize the hybrid retriever
69+
retriever = OpenSearchHybridRetriever(
70+
document_store=doc_store,
71+
embedder=embedder,
72+
top_k_bm25=3,
73+
top_k_embedding=3,
74+
join_mode="reciprocal_rank_fusion"
75+
)
76+
77+
# Run the retriever
78+
results = retriever.run(query="What is reinforcement learning?", filters_bm25=None, filters_embedding=None)
79+
80+
>> results['documents']
81+
{'documents': [Document(id=..., content: 'Reinforcement learning is a type of machine learning.', score: 1.0),
82+
Document(id=..., content: 'Supervised learning is a type of machine learning.', score: 0.9760624679979518),
83+
Document(id=..., content: 'Deep learning is a subset of machine learning.', score: 0.4919354838709677),
84+
Document(id=..., content: 'Machine learning is a subset of artificial intelligence.', score: 0.4841269841269841)]}
85+
```
86+
"""
87+
88+
def __init__(
89+
self,
90+
document_store: OpenSearchDocumentStore,
91+
*,
92+
embedder: TextEmbedder,
93+
# OpenSearchBM25Retriever
94+
filters_bm25: Optional[Dict[str, Any]] = None,
95+
fuzziness: Union[int, str] = "AUTO",
96+
top_k_bm25: int = 10,
97+
scale_score: bool = False,
98+
all_terms_must_match: bool = False,
99+
filter_policy_bm25: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
100+
custom_query_bm25: Optional[Dict[str, Any]] = None,
101+
# OpenSearchEmbeddingRetriever
102+
filters_embedding: Optional[Dict[str, Any]] = None,
103+
top_k_embedding: int = 10,
104+
filter_policy_embedding: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
105+
custom_query_embedding: Optional[Dict[str, Any]] = None,
106+
# DocumentJoiner
107+
join_mode: Union[str, JoinMode] = JoinMode.RECIPROCAL_RANK_FUSION,
108+
weights: Optional[List[float]] = None,
109+
top_k: Optional[int] = None,
110+
sort_by_score: bool = True,
111+
# extra kwargs
112+
**kwargs,
113+
):
114+
"""
115+
Initialize the OpenSearchHybridRetriever, a super component to retrieve documents from OpenSearch using
116+
both embedding-based and keyword-based retrieval methods.
117+
118+
We don't explicitly define all the init parameters of the components in the constructor, for each
119+
of the components, since that would be around 20+ parameters. Instead, we define the most important ones
120+
and pass the rest as kwargs. This is to keep the constructor clean and easy to read.
121+
122+
If you need to pass extra parameters to the components, you can do so by passing them as kwargs. It expects
123+
a dictionary with the component name as the key and the parameters as the value. The component name should be:
124+
125+
- "bm25_retriever" -> OpenSearchBM25Retriever
126+
- "embedding_retriever" -> OpenSearchEmbeddingRetriever
127+
128+
:param document_store:
129+
The OpenSearchDocumentStore to use for retrieval.
130+
:param embedder:
131+
A TextEmbedder to use for embedding the query.
132+
See `haystack.components.embedders.types.protocol.TextEmbedder` for more information.
133+
:param filters_bm25:
134+
Filters for the BM25 retriever.
135+
:param fuzziness:
136+
The fuzziness for the BM25 retriever.
137+
:param top_k_bm25:
138+
The number of results to return from the BM25 retriever.
139+
:param scale_score:
140+
Whether to scale the score for the BM25 retriever.
141+
:param all_terms_must_match:
142+
Whether all terms must match for the BM25 retriever.
143+
:param filter_policy_bm25:
144+
The filter policy for the BM25 retriever.
145+
:param custom_query_bm25:
146+
A custom query for the BM25 retriever.
147+
:param filters_embedding:
148+
Filters for the embedding retriever.
149+
:param top_k_embedding:
150+
The number of results to return from the embedding retriever.
151+
:param filter_policy_embedding:
152+
The filter policy for the embedding retriever.
153+
:param custom_query_embedding:
154+
A custom query for the embedding retriever.
155+
:param join_mode:
156+
The mode to use for joining the results from the BM25 and embedding retrievers.
157+
:param weights:
158+
The weights for the joiner.
159+
:param top_k:
160+
The number of results to return from the joiner.
161+
:param sort_by_score:
162+
Whether to sort the results by score.
163+
:param **kwargs:
164+
Additional keyword arguments. Use the following keys to pass extra parameters to the retrievers:
165+
- "bm25_retriever" -> OpenSearchBM25Retriever
166+
- "embedding_retriever" -> OpenSearchEmbeddingRetriever
167+
168+
169+
"""
170+
self.document_store = document_store
171+
self.embedder = embedder
172+
173+
# OpenSearchBM25Retriever
174+
self.filters_bm25 = filters_bm25
175+
self.fuzziness = fuzziness
176+
self.top_k_bm25 = top_k_bm25
177+
self.scale_score = scale_score
178+
self.all_terms_must_match = all_terms_must_match
179+
self.filter_policy_bm25 = filter_policy_bm25
180+
self.custom_query_bm25 = custom_query_bm25
181+
182+
# OpenSearchEmbeddingRetriever
183+
self.filters_embedding = filters_embedding
184+
self.top_k_embedding = top_k_embedding
185+
self.filter_policy_embedding = filter_policy_embedding
186+
self.custom_query_embedding = custom_query_embedding
187+
188+
# DocumentJoiner
189+
self.join_mode = join_mode
190+
self.weights = weights
191+
self.top_k = top_k
192+
self.sort_by_score = sort_by_score
193+
194+
init_args = {
195+
"bm25_retriever": {
196+
"document_store": self.document_store,
197+
"filters": self.filters_bm25,
198+
"fuzziness": self.fuzziness,
199+
"top_k": self.top_k_bm25,
200+
"scale_score": self.scale_score,
201+
"all_terms_must_match": self.all_terms_must_match,
202+
"filter_policy": self.filter_policy_bm25,
203+
"custom_query": self.custom_query_bm25,
204+
},
205+
"embedding_retriever": {
206+
"document_store": self.document_store,
207+
"filters": self.filters_embedding,
208+
"top_k": self.top_k_embedding,
209+
"filter_policy": self.filter_policy_embedding,
210+
"custom_query": self.custom_query_embedding,
211+
},
212+
"document_joiner": {
213+
"join_mode": self.join_mode,
214+
"weights": self.weights,
215+
"top_k": self.top_k,
216+
"sort_by_score": self.sort_by_score,
217+
},
218+
}
219+
220+
for k in kwargs:
221+
if k not in ["bm25_retriever", "embedding_retriever"]:
222+
msg = f"valid extra args are only: 'bm25_retriever' and 'embedding_retriever'. Found: {k}"
223+
raise ValueError(msg)
224+
225+
self.extra_args = kwargs
226+
227+
# handle extra kwargs for the bm25 and embedding retrievers and the doc store as init param
228+
if "bm25_retriever" in kwargs:
229+
init_args["bm25_retriever"].update(kwargs["bm25_retriever"])
230+
init_args["bm25_retriever"]["document_store"] = self.document_store
231+
if "embedding_retriever" in kwargs:
232+
init_args["embedding_retriever"].update(kwargs["embedding_retriever"])
233+
init_args["embedding_retriever"]["document_store"] = self.document_store
234+
235+
self.pipeline = self._create_pipeline(init_args)
236+
237+
if TYPE_CHECKING:
238+
239+
def warm_up(self) -> None: ...
240+
241+
def run(self, query: str, filters_bm25=None, filters_embedding=None) -> Dict[str, Any]: ...
242+
243+
def _create_pipeline(self, data: dict[str, Any]) -> Pipeline:
244+
"""
245+
Create the pipeline for the OpenSearchHybridRetriever.
246+
"""
247+
embedding_retriever = OpenSearchEmbeddingRetriever(**data["embedding_retriever"])
248+
bm25_retriever = OpenSearchBM25Retriever(**data["bm25_retriever"])
249+
document_joiner = DocumentJoiner(**data["document_joiner"])
250+
251+
hybrid_retrieval = Pipeline()
252+
hybrid_retrieval.add_component("text_embedder", self.embedder)
253+
hybrid_retrieval.add_component("embedding_retriever", embedding_retriever)
254+
hybrid_retrieval.add_component("bm25_retriever", bm25_retriever)
255+
hybrid_retrieval.add_component("document_joiner", document_joiner)
256+
257+
hybrid_retrieval.connect("text_embedder.embedding", "embedding_retriever.query_embedding")
258+
hybrid_retrieval.connect("bm25_retriever", "document_joiner")
259+
hybrid_retrieval.connect("embedding_retriever", "document_joiner")
260+
261+
# Define how pipeline inputs/outputs map to subcomponent inputs/outputs
262+
self.input_mapping = {
263+
# The pipeline input "query" feeds into each of the retrievers
264+
"query": ["text_embedder.text", "bm25_retriever.query"],
265+
}
266+
self.output_mapping = {"document_joiner.documents": "documents"}
267+
268+
return hybrid_retrieval
269+
270+
def to_dict(self):
271+
"""
272+
Serialize OpenSearchHybridRetriever to a dictionary.
273+
274+
:returns:
275+
Dictionary with serialized data.
276+
"""
277+
return default_to_dict(
278+
self,
279+
# DocumentStore
280+
document_store=self.document_store.to_dict(),
281+
embedder=self.embedder.to_dict(),
282+
filters_bm25=self.filters_bm25,
283+
fuzziness=self.fuzziness,
284+
top_k_bm25=self.top_k_bm25,
285+
scale_score=self.scale_score,
286+
all_terms_must_match=self.all_terms_must_match,
287+
filter_policy_bm25=self.filter_policy_bm25.value,
288+
custom_query_bm25=self.custom_query_bm25,
289+
# OpenSearchEmbeddingRetriever
290+
filters_embedding=self.filters_embedding,
291+
top_k_embedding=self.top_k_embedding,
292+
filter_policy_embedding=self.filter_policy_embedding.value,
293+
custom_query_embedding=self.custom_query_embedding,
294+
# DocumentJoiner
295+
join_mode=self.join_mode.value,
296+
weights=self.weights,
297+
top_k=self.top_k,
298+
sort_by_score=self.sort_by_score,
299+
# extra kwargs
300+
**self.extra_args,
301+
)
302+
303+
@classmethod
304+
def from_dict(cls, data):
305+
# deserialize the document store
306+
doc_store = OpenSearchDocumentStore.from_dict(data["init_parameters"]["document_store"])
307+
data["init_parameters"]["document_store"] = doc_store
308+
309+
# deserialize the embedder
310+
try:
311+
text_embedder_class = import_class_by_name(data["init_parameters"]["embedder"]["type"])
312+
except ImportError as e:
313+
msg = f"Class '{data['init_parameters']['embedder']['type']}' not correctly imported"
314+
raise DeserializationError(msg) from e
315+
316+
data["init_parameters"]["embedder"] = component_from_dict(
317+
cls=text_embedder_class, data=data["init_parameters"]["embedder"], name="embedder"
318+
)
319+
320+
# deserialize the embedders filtering policy
321+
if "filter_policy_bm25" in data["init_parameters"]:
322+
filter_policy_bm25 = FilterPolicy.from_str(data["init_parameters"]["filter_policy_bm25"])
323+
data["init_parameters"]["filter_policy_bm25"] = filter_policy_bm25
324+
325+
if "filter_policy_embedding" in data["init_parameters"]:
326+
filter_policy_embedding = FilterPolicy.from_str(data["init_parameters"]["filter_policy_embedding"])
327+
data["init_parameters"]["filter_policy_embedding"] = filter_policy_embedding
328+
329+
if "join_mode" in data["init_parameters"]:
330+
join_mode = JoinMode.from_str(data["init_parameters"]["join_mode"])
331+
data["init_parameters"]["join_mode"] = join_mode
332+
333+
return default_from_dict(cls, data)

0 commit comments

Comments
 (0)