1
1
from abc import ABC , abstractmethod
2
2
from collections import Counter
3
- from dataclasses import dataclass
4
3
from typing import ClassVar , TypeVar
5
4
6
5
import tiktoken
6
+ from pydantic import BaseModel
7
7
8
8
from ragbits .core import embeddings
9
9
from ragbits .core .options import Options
13
13
SparseEmbeddingsOptionsT = TypeVar ("SparseEmbeddingsOptionsT" , bound = Options )
14
14
15
15
16
class SparseVector(BaseModel):
    """Sparse Vector representation.

    A sparse vector is stored as two parallel lists: `indices` holds the
    positions of the non-zero dimensions and `values` holds the value at
    each corresponding position.
    """

    # indices[i] is the dimension whose value is values[i]; both lists must
    # therefore have the same length (enforced in model_post_init below).
    indices: list[int]
    values: list[float]

    def model_post_init(self, __context: object) -> None:
        # BUGFIX: pydantic's BaseModel never calls `__post_init__` (that is
        # the dataclass hook), so after the dataclass -> BaseModel migration
        # this length check silently stopped running. `model_post_init` is
        # pydantic v2's post-initialization hook and restores the validation.
        if len(self.indices) != len(self.values):
            raise ValueError("There should be the same number of non-zero values as non-zero positions")

    def __repr__(self) -> str:
        return f"SparseVector(indices={self.indices}, values={self.values})"
32
28
33
29
34
30
class SparseEmbeddings (ConfigurableComponent [SparseEmbeddingsOptionsT ], ABC ):
@@ -39,7 +35,7 @@ class SparseEmbeddings(ConfigurableComponent[SparseEmbeddingsOptionsT], ABC):
39
35
configuration_key : ClassVar = "sparse_embedder"
40
36
41
37
    @abstractmethod
    async def embed_text(self, texts: list[str], options: SparseEmbeddingsOptionsT | None = None) -> list[SparseVector]:
        """Transforms a list of texts into sparse vectors.

        Args:
            texts: input texts to embed.
            options: optional per-call options; when None, implementations
                presumably fall back to their configured defaults (see
                BagOfTokens.embed_text) — confirm for other subclasses.

        Returns:
            One SparseVector per input text, in the same order.
        """
44
40
45
41
@@ -52,11 +48,11 @@ class BagOfTokensOptions(Options):
52
48
53
49
54
50
class BagOfTokens(SparseEmbeddings[BagOfTokensOptions]):
    """BagOfTokens implementations of sparse Embeddings interface"""

    options_cls = BagOfTokensOptions

    async def embed_text(self, texts: list[str], options: BagOfTokensOptions | None = None) -> list[SparseVector]:
        """
        Transforms a list of texts into sparse vectors using bag-of-tokens representation.

        Args:
            texts: list of input texts.
            options: optional embedding options; merged over the embedder's defaults.

        Returns:
            One SparseVector per input text: token ids as indices, token
            occurrence counts as values.

        Raises:
            ValueError: if both of `encoding_name` / `model_name` are set,
                or if neither is set.
        """
        merged_options = (self.default_options | options) if options else self.default_options

        if merged_options.encoding_name and merged_options.model_name:
            raise ValueError("Please specify only one of encoding_name or model_name")

        # Single validation point: the else branch below replaces the previous
        # duplicated up-front "neither specified" check, which raised the very
        # same message a second time.
        if merged_options.encoding_name:
            encoder = tiktoken.get_encoding(encoding_name=merged_options.encoding_name)
        elif merged_options.model_name:
            encoder = tiktoken.encoding_for_model(model_name=merged_options.model_name)
        else:
            raise ValueError("Either encoding_name or model_name needs to be specified")

        # Tokens occurring fewer than min_token_count times are dropped;
        # None disables the filter (-inf never excludes anything).
        min_token_count = merged_options.min_token_count or float("-inf")

        vectors = []
        for text in texts:
            token_counts = Counter(encoder.encode(text))
            non_zero_dims: list[int] = []
            non_zero_vals: list[float] = []
            for token, count in token_counts.items():
                if count < min_token_count:
                    continue
                non_zero_dims.append(token)
                # SparseVector.values is list[float]; store counts as floats.
                non_zero_vals.append(float(count))
            vectors.append(SparseVector(indices=non_zero_dims, values=non_zero_vals))
        return vectors
0 commit comments