Skip to content

Commit a760679

Browse files
authored
feat: add fastembed embeddings (#374)
1 parent 9e51faa commit a760679

File tree

9 files changed

+340
-104
lines changed

9 files changed

+340
-104
lines changed

.libraries-whitelist.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ chroma-hnswlib
55
rouge
66
distilabel
77
rerankers
8+
py_rust_stemmers

packages/ragbits-core/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Add support to fastembed dense & sparse embeddings.
56
- Fix: changed variable type from Filter to WhereQuery in the Qdrant vector store in list method.
67

78
## 0.8.0 (2025-01-29)

packages/ragbits-core/pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ local = [
5454
"transformers~=4.44.2",
5555
"numpy~=1.26.0"
5656
]
57+
fastembed = [
58+
"fastembed>=0.4.2"
59+
]
5760
lab = [
5861
"gradio~=4.44.0",
5962
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from fastembed import SparseTextEmbedding, TextEmbedding
2+
3+
from ragbits.core.embeddings import Embeddings, EmbeddingsOptionsT, SparseEmbeddings
4+
from ragbits.core.embeddings.sparse import SparseVector
5+
from ragbits.core.options import Options
6+
7+
8+
class FastEmbedOptions(Options):
9+
"""
10+
Dataclass that represents available call options for the LocalEmbeddings client.
11+
"""
12+
13+
batch_size: int = 256
14+
parallel: int | None = None
15+
16+
17+
class FastEmbedEmbeddings(Embeddings[FastEmbedOptions]):
18+
"""
19+
Class for creating dense text embeddings using FastEmbed library.
20+
For more information, see the [FastEmbed GitHub](https://github.com/qdrant/fastembed).
21+
"""
22+
23+
options_cls = FastEmbedOptions
24+
_model: TextEmbedding
25+
26+
def __init__(self, model_name: str, default_options: FastEmbedOptions | None = None):
27+
super().__init__(default_options=default_options)
28+
self.model_name = model_name
29+
self._model = TextEmbedding(model_name)
30+
31+
async def embed_text(self, data: list[str], options: EmbeddingsOptionsT | None = None) -> list[list[float]]:
32+
"""
33+
Embeds a list of strings into a list of embeddings.
34+
35+
Args:
36+
data: List of strings to get embeddings for.
37+
options: Additional options to pass to the embedding model.
38+
39+
Returns:
40+
List of embeddings for the given strings.
41+
"""
42+
merged_options = (self.default_options | options) if options else self.default_options
43+
44+
return [[float(x) for x in result] for result in self._model.embed(data, **merged_options.dict())]
45+
46+
47+
class FastEmbedSparseEmbeddings(SparseEmbeddings[FastEmbedOptions]):
48+
"""
49+
Class for creating sparse text embeddings using FastEmbed library.
50+
For more information, see the [FastEmbed GitHub](https://github.com/qdrant/fastembed).
51+
"""
52+
53+
options_cls = FastEmbedOptions
54+
_model: SparseTextEmbedding
55+
56+
def __init__(self, model_name: str, default_options: FastEmbedOptions | None = None):
57+
super().__init__(default_options=default_options)
58+
self.model_name = model_name
59+
self._model = SparseTextEmbedding(model_name)
60+
61+
async def embed_text(self, data: list[str], options: EmbeddingsOptionsT | None = None) -> list[SparseVector]:
62+
"""
63+
Embeds a list of strings into a list of sparse embeddings.
64+
65+
Args:
66+
data: List of strings to get embeddings for.
67+
options: Additional options to pass to the embedding model.
68+
69+
Returns:
70+
List of embeddings for the given strings.
71+
"""
72+
merged_options = (self.default_options | options) if options else self.default_options
73+
74+
return [
75+
SparseVector(values=[float(x) for x in result.values], indices=[int(x) for x in result.indices])
76+
for result in self._model.embed(data, **merged_options.dict())
77+
]

packages/ragbits-core/src/ragbits/core/embeddings/sparse.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from abc import ABC, abstractmethod
22
from collections import Counter
3-
from dataclasses import dataclass
43
from typing import ClassVar, TypeVar
54

65
import tiktoken
6+
from pydantic import BaseModel
77

88
from ragbits.core import embeddings
99
from ragbits.core.options import Options
@@ -13,22 +13,18 @@
1313
SparseEmbeddingsOptionsT = TypeVar("SparseEmbeddingsOptionsT", bound=Options)
1414

1515

16-
@dataclass
17-
class SparseVector:
16+
class SparseVector(BaseModel):
1817
"""Sparse Vector representation"""
1918

20-
non_zero_dims: list[int]
21-
non_zero_vals: list[int]
22-
dim: int
19+
indices: list[int]
20+
values: list[float]
2321

2422
def __post_init__(self) -> None:
25-
if len(self.non_zero_dims) != len(self.non_zero_vals):
23+
if len(self.indices) != len(self.values):
2624
raise ValueError("There should be the same number of non-zero values as non-zero positions")
27-
if any(dim >= self.dim or dim < 0 for dim in self.non_zero_dims):
28-
raise ValueError("Indexes should be in the range of the vector dim")
2925

3026
def __repr__(self) -> str:
31-
return f"SparseVector(non_zero_dims={self.non_zero_dims}, non_zero_vals={self.non_zero_vals}, dim={self.dim})"
27+
return f"SparseVector(indices={self.indices}, values={self.values})"
3228

3329

3430
class SparseEmbeddings(ConfigurableComponent[SparseEmbeddingsOptionsT], ABC):
@@ -39,7 +35,7 @@ class SparseEmbeddings(ConfigurableComponent[SparseEmbeddingsOptionsT], ABC):
3935
configuration_key: ClassVar = "sparse_embedder"
4036

4137
@abstractmethod
42-
def embed_text(self, texts: list[str], options: SparseEmbeddingsOptionsT | None = None) -> list[SparseVector]:
38+
async def embed_text(self, texts: list[str], options: SparseEmbeddingsOptionsT | None = None) -> list[SparseVector]:
4339
"""Transforms a list of texts into sparse vectors"""
4440

4541

@@ -52,11 +48,11 @@ class BagOfTokensOptions(Options):
5248

5349

5450
class BagOfTokens(SparseEmbeddings[BagOfTokensOptions]):
55-
"""BagofTokens implementations of sparse Embeddings interface"""
51+
"""BagOfTokens implementations of sparse Embeddings interface"""
5652

5753
options_cls = BagOfTokensOptions
5854

59-
def embed_text(self, texts: list[str], options: BagOfTokensOptions | None = None) -> list[SparseVector]:
55+
async def embed_text(self, texts: list[str], options: BagOfTokensOptions | None = None) -> list[SparseVector]:
6056
"""
6157
Transforms a list of texts into sparse vectors using bag-of-tokens representation.
6258
@@ -73,12 +69,14 @@ def embed_text(self, texts: list[str], options: BagOfTokensOptions | None = None
7369
raise ValueError("Please specify only one of encoding_name or model_name")
7470
if not (merged_options.encoding_name or merged_options.model_name):
7571
raise ValueError("Either encoding_name or model_name needs to be specified")
72+
7673
if merged_options.encoding_name:
7774
encoder = tiktoken.get_encoding(encoding_name=merged_options.encoding_name)
78-
if merged_options.model_name:
75+
elif merged_options.model_name:
7976
encoder = tiktoken.encoding_for_model(model_name=merged_options.model_name)
77+
else:
78+
raise ValueError("Either encoding_name or model_name needs to be specified")
8079

81-
dim = encoder.n_vocab
8280
min_token_count = merged_options.min_token_count or float("-inf")
8381
for text in texts:
8482
tokens = encoder.encode(text)
@@ -90,7 +88,7 @@ def embed_text(self, texts: list[str], options: BagOfTokensOptions | None = None
9088
if count < min_token_count:
9189
continue
9290
non_zero_dims.append(token)
93-
non_zero_vals.append(count)
91+
non_zero_vals.append(float(count))
9492

95-
vectors.append(SparseVector(non_zero_dims, non_zero_vals, dim))
93+
vectors.append(SparseVector(indices=non_zero_dims, values=non_zero_vals))
9694
return vectors

packages/ragbits-core/tests/unit/embeddings/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from ragbits.core.embeddings.fastembed import FastEmbedEmbeddings, FastEmbedSparseEmbeddings
2+
3+
4+
async def test_fastembed_dense_embeddings():
5+
embeddings = FastEmbedEmbeddings("BAAI/bge-small-en-v1.5")
6+
result = await embeddings.embed_text(["text1"])
7+
assert len(result[0]) == 384
8+
9+
10+
async def test_fastembed_sparse_embeddings():
11+
embeddings = FastEmbedSparseEmbeddings("qdrant/bm25")
12+
result = await embeddings.embed_text(["text1"])
13+
assert len(result[0].values) == len(result[0].indices)

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ requires-python = ">=3.10"
77
dependencies = [
88
"asyncpg>=0.30.0",
99
"ragbits-cli",
10-
"ragbits-core[chroma,lab,local,otel,qdrant]",
10+
"ragbits-core[chroma,lab,fastembed,local,otel,qdrant]",
1111
"ragbits-document-search[gcs,huggingface,distributed,azure,s3]",
1212
"ragbits-evaluate[relari]",
1313
"ragbits-guardrails[openai]",
14-
"ragbits-conversations"
14+
"ragbits-conversations",
1515
]
1616

1717
[tool.uv]

0 commit comments

Comments
 (0)