Skip to content

Commit 00354a6

Browse files
anakin87 and masci
authored
Qdrant - add hybrid retriever (deepset-ai#675)
* qdrant hybrid retriever * Apply suggestions from code review Co-authored-by: Massimiliano Pippi <mpippi@gmail.com> * use fixture * scope session and little fixes * fix linting * exception handling --------- Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
1 parent 547c375 commit 00354a6

File tree

6 files changed

+413
-25
lines changed

6 files changed

+413
-25
lines changed

integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from .retriever import QdrantEmbeddingRetriever, QdrantSparseEmbeddingRetriever
5+
from .retriever import QdrantEmbeddingRetriever, QdrantHybridRetriever, QdrantSparseEmbeddingRetriever
66

7-
__all__ = ("QdrantEmbeddingRetriever", "QdrantSparseEmbeddingRetriever")
7+
__all__ = ("QdrantEmbeddingRetriever", "QdrantSparseEmbeddingRetriever", "QdrantHybridRetriever")

integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py

Lines changed: 132 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ class QdrantEmbeddingRetriever:
1919
":memory:",
2020
recreate_index=True,
2121
return_embedding=True,
22-
wait_result_from_api=True,
2322
)
23+
24+
document_store.write_documents([Document(content="test", embedding=[0.5]*768)])
25+
2426
retriever = QdrantEmbeddingRetriever(document_store=document_store)
2527
2628
# using a fake vector to keep the example simple
@@ -112,7 +114,7 @@ def run(
112114
The retrieved documents.
113115
114116
"""
115-
docs = self._document_store.query_by_embedding(
117+
docs = self._document_store._query_by_embedding(
116118
query_embedding=query_embedding,
117119
filters=filters or self._filters,
118120
top_k=top_k or self._top_k,
@@ -136,10 +138,14 @@ class QdrantSparseEmbeddingRetriever:
136138
137139
document_store = QdrantDocumentStore(
138140
":memory:",
141+
use_sparse_embeddings=True,
139142
recreate_index=True,
140143
return_embedding=True,
141-
wait_result_from_api=True,
142144
)
145+
146+
doc = Document(content="test", sparse_embedding=SparseEmbedding(indices=[0, 3, 5], values=[0.1, 0.5, 0.12]))
147+
document_store.write_documents([doc])
148+
143149
retriever = QdrantSparseEmbeddingRetriever(document_store=document_store)
144150
sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])
145151
retriever.run(query_sparse_embedding=sparse_embedding)
@@ -196,7 +202,7 @@ def to_dict(self) -> Dict[str, Any]:
196202
return d
197203

198204
@classmethod
199-
def from_dict(cls, data: Dict[str, Any]) -> "QdrantEmbeddingRetriever":
205+
def from_dict(cls, data: Dict[str, Any]) -> "QdrantSparseEmbeddingRetriever":
200206
"""
201207
Deserializes the component from a dictionary.
202208
@@ -230,7 +236,7 @@ def run(
230236
The retrieved documents.
231237
232238
"""
233-
docs = self._document_store.query_by_sparse(
239+
docs = self._document_store._query_by_sparse(
234240
query_sparse_embedding=query_sparse_embedding,
235241
filters=filters or self._filters,
236242
top_k=top_k or self._top_k,
@@ -239,3 +245,124 @@ def run(
239245
)
240246

241247
return {"documents": docs}
248+
249+
250+
@component
251+
class QdrantHybridRetriever:
252+
"""
253+
A component for retrieving documents from a QdrantDocumentStore using both dense and sparse vectors
254+
and fusing the results using Reciprocal Rank Fusion.
255+
256+
Usage example:
257+
```python
258+
from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
259+
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
260+
from haystack.dataclasses.sparse_embedding import SparseEmbedding
261+
262+
document_store = QdrantDocumentStore(
263+
":memory:",
264+
use_sparse_embeddings=True,
265+
recreate_index=True,
266+
return_embedding=True,
267+
wait_result_from_api=True,
268+
)
269+
270+
doc = Document(content="test",
271+
embedding=[0.5]*768,
272+
sparse_embedding=SparseEmbedding(indices=[0, 3, 5], values=[0.1, 0.5, 0.12]))
273+
274+
document_store.write_documents([doc])
275+
276+
retriever = QdrantHybridRetriever(document_store=document_store)
277+
embedding = [0.1]*768
278+
sparse_embedding = SparseEmbedding(indices=[0, 1, 2, 3], values=[0.1, 0.8, 0.05, 0.33])
279+
retriever.run(query_embedding=embedding, query_sparse_embedding=sparse_embedding)
280+
```
281+
"""
282+
283+
def __init__(
284+
self,
285+
document_store: QdrantDocumentStore,
286+
filters: Optional[Dict[str, Any]] = None,
287+
top_k: int = 10,
288+
return_embedding: bool = False,
289+
):
290+
"""
291+
Create a QdrantHybridRetriever component.
292+
293+
:param document_store: An instance of QdrantDocumentStore.
294+
:param filters: A dictionary with filters to narrow down the search space.
295+
:param top_k: The maximum number of documents to retrieve.
296+
:param return_embedding: Whether to return the embeddings of the retrieved Documents.
297+
298+
:raises ValueError: If 'document_store' is not an instance of QdrantDocumentStore.
299+
"""
300+
301+
if not isinstance(document_store, QdrantDocumentStore):
302+
msg = "document_store must be an instance of QdrantDocumentStore"
303+
raise ValueError(msg)
304+
305+
self._document_store = document_store
306+
self._filters = filters
307+
self._top_k = top_k
308+
self._return_embedding = return_embedding
309+
310+
def to_dict(self) -> Dict[str, Any]:
311+
"""
312+
Serializes the component to a dictionary.
313+
314+
:returns:
315+
Dictionary with serialized data.
316+
"""
317+
return default_to_dict(
318+
self,
319+
document_store=self._document_store.to_dict(),
320+
filters=self._filters,
321+
top_k=self._top_k,
322+
return_embedding=self._return_embedding,
323+
)
324+
325+
@classmethod
326+
def from_dict(cls, data: Dict[str, Any]) -> "QdrantHybridRetriever":
327+
"""
328+
Deserializes the component from a dictionary.
329+
330+
:param data:
331+
Dictionary to deserialize from.
332+
:returns:
333+
Deserialized component.
334+
"""
335+
document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"])
336+
data["init_parameters"]["document_store"] = document_store
337+
return default_from_dict(cls, data)
338+
339+
@component.output_types(documents=List[Document])
340+
def run(
341+
self,
342+
query_embedding: List[float],
343+
query_sparse_embedding: SparseEmbedding,
344+
filters: Optional[Dict[str, Any]] = None,
345+
top_k: Optional[int] = None,
346+
return_embedding: Optional[bool] = None,
347+
):
348+
"""
349+
Run the Hybrid Retriever on the given input data.
350+
351+
:param query_embedding: Dense embedding of the query.
352+
:param query_sparse_embedding: Sparse embedding of the query.
353+
:param filters: A dictionary with filters to narrow down the search space.
354+
:param top_k: The maximum number of documents to return.
355+
:param return_embedding: Whether to return the embedding of the retrieved Documents.
356+
:returns:
357+
The retrieved documents.
358+
359+
"""
360+
docs = self._document_store._query_hybrid(
361+
query_embedding=query_embedding,
362+
query_sparse_embedding=query_sparse_embedding,
363+
filters=filters or self._filters,
364+
top_k=top_k or self._top_k,
365+
return_embedding=return_embedding or self._return_embedding,
366+
)
367+
368+
return {"documents": docs}

integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from qdrant_client import grpc
1717
from qdrant_client.http import models as rest
1818
from qdrant_client.http.exceptions import UnexpectedResponse
19+
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion
1920
from tqdm import tqdm
2021

2122
from .converters import (
@@ -307,7 +308,7 @@ def get_documents_by_id(
307308
)
308309
return documents
309310

310-
def query_by_sparse(
311+
def _query_by_sparse(
311312
self,
312313
query_sparse_embedding: SparseEmbedding,
313314
filters: Optional[Dict[str, Any]] = None,
@@ -349,7 +350,7 @@ def query_by_sparse(
349350
document.score = score
350351
return results
351352

352-
def query_by_embedding(
353+
def _query_by_embedding(
353354
self,
354355
query_embedding: List[float],
355356
filters: Optional[Dict[str, Any]] = None,
@@ -383,6 +384,86 @@ def query_by_embedding(
383384
document.score = score
384385
return results
385386

387+
def _query_hybrid(
388+
self,
389+
query_embedding: List[float],
390+
query_sparse_embedding: SparseEmbedding,
391+
filters: Optional[Dict[str, Any]] = None,
392+
top_k: int = 10,
393+
return_embedding: bool = False,
394+
) -> List[Document]:
395+
"""
396+
Retrieves documents based on dense and sparse embeddings and fuses the results using Reciprocal Rank Fusion.
397+
398+
This method is not part of the public interface of `QdrantDocumentStore` and shouldn't be used directly.
399+
Use the `QdrantHybridRetriever` instead.
400+
401+
:param query_embedding: Dense embedding of the query.
402+
:param query_sparse_embedding: Sparse embedding of the query.
403+
:param filters: Filters applied to the retrieved Documents.
404+
:param top_k: Maximum number of Documents to return.
405+
:param return_embedding: Whether to return the embeddings of the retrieved documents.
406+
407+
:returns: List of Document that are most similar to `query_embedding` and `query_sparse_embedding`.
408+
409+
:raises QdrantStoreError:
410+
If the Document Store was initialized with `use_sparse_embeddings=False`.
411+
"""
412+
413+
# This implementation is based on the code from the Python Qdrant client:
414+
# https://github.com/qdrant/qdrant-client/blob/8e3ea58f781e4110d11c0a6985b5e6bb66b85d33/qdrant_client/qdrant_fastembed.py#L519
415+
if not self.use_sparse_embeddings:
416+
message = (
417+
"You are trying to query using sparse embeddings, but the Document Store "
418+
"was initialized with `use_sparse_embeddings=False`. "
419+
)
420+
raise QdrantStoreError(message)
421+
422+
qdrant_filters = convert_filters_to_qdrant(filters)
423+
424+
sparse_request = rest.SearchRequest(
425+
vector=rest.NamedSparseVector(
426+
name=SPARSE_VECTORS_NAME,
427+
vector=rest.SparseVector(
428+
indices=query_sparse_embedding.indices,
429+
values=query_sparse_embedding.values,
430+
),
431+
),
432+
filter=qdrant_filters,
433+
limit=top_k,
434+
with_payload=True,
435+
with_vector=return_embedding,
436+
)
437+
438+
dense_request = rest.SearchRequest(
439+
vector=rest.NamedVector(
440+
name=DENSE_VECTORS_NAME,
441+
vector=query_embedding,
442+
),
443+
filter=qdrant_filters,
444+
limit=top_k,
445+
with_payload=True,
446+
with_vector=return_embedding,
447+
)
448+
449+
try:
450+
dense_request_response, sparse_request_response = self.client.search_batch(
451+
collection_name=self.index, requests=[dense_request, sparse_request]
452+
)
453+
except Exception as e:
454+
msg = "Error during hybrid search"
455+
raise QdrantStoreError(msg) from e
456+
457+
try:
458+
points = reciprocal_rank_fusion(responses=[dense_request_response, sparse_request_response], limit=top_k)
459+
except Exception as e:
460+
msg = "Error while applying Reciprocal Rank Fusion"
461+
raise QdrantStoreError(msg) from e
462+
463+
results = [convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) for point in points]
464+
465+
return results
466+
386467
def _get_distance(self, similarity: str) -> rest.Distance:
387468
try:
388469
return self.SIMILARITY[similarity]

integrations/qdrant/tests/conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import numpy as np
2+
import pytest
3+
from haystack.dataclasses import SparseEmbedding
4+
5+
6+
@pytest.fixture(scope="session")
7+
def generate_sparse_embedding():
8+
"""
9+
This fixture returns a function that generates a random SparseEmbedding each time it is called.
10+
"""
11+
12+
def _generate_random_sparse_embedding():
13+
random_indice_length = np.random.randint(3, 15)
14+
indices = list(range(random_indice_length))
15+
values = [np.random.random_sample() for _ in range(random_indice_length)]
16+
return SparseEmbedding(indices=indices, values=values)
17+
18+
return _generate_random_sparse_embedding

0 commit comments

Comments
 (0)