7
7
from pydantic .json import pydantic_encoder
8
8
9
9
from ragbits .core .audit .traces import trace
10
- from ragbits .core .embeddings .base import Embedder , SparseVector
10
+ from ragbits .core .embeddings .base import Embedder , SparseVector , VectorSize
11
11
from ragbits .core .embeddings .sparse .base import SparseEmbedder
12
12
from ragbits .core .vector_stores .base import (
13
13
EmbeddingType ,
@@ -53,8 +53,8 @@ def __init__(
53
53
self ,
54
54
client : asyncpg .Pool ,
55
55
table_name : str ,
56
- vector_size : int ,
57
56
embedder : Embedder ,
57
+ vector_size : int | None = None ,
58
58
embedding_type : EmbeddingType = EmbeddingType .TEXT ,
59
59
distance_method : str | None = None ,
60
60
hnsw_params : dict | None = None ,
@@ -66,8 +66,8 @@ def __init__(
66
66
Args:
67
67
client: The pgVector database connection pool.
68
68
table_name: The name of the table.
69
- vector_size: The size of the vectors.
70
69
embedder: The embedder to use for converting entries to vectors.
70
+ vector_size: The size of the vectors. If None, will be determined automatically from the embedder.
71
71
embedding_type: Which part of the entry to embed, either text or image. The other part will be ignored.
72
72
distance_method: The distance method to use, default is "cosine" for dense vectors
73
73
and "sparsevec_l2" for sparse vectors.
@@ -84,7 +84,7 @@ def __init__(
84
84
85
85
if not re .match (r"^[a-zA-Z_][a-zA-Z0-9_]*$" , table_name ):
86
86
raise ValueError (f"Invalid table name: { table_name } " )
87
- if not isinstance (vector_size , int ) or vector_size <= 0 :
87
+ if vector_size is not None and ( not isinstance (vector_size , int ) or vector_size <= 0 ) :
88
88
raise ValueError ("Vector size must be a positive integer." )
89
89
90
90
if hnsw_params is None :
@@ -103,6 +103,7 @@ def __init__(
103
103
self ._client = client
104
104
self ._table_name = table_name
105
105
self ._vector_size = vector_size
106
+ self ._vector_size_info : VectorSize | None = None
106
107
self ._distance_method = distance_method
107
108
self ._hnsw_params = hnsw_params
108
109
@@ -113,6 +114,32 @@ def __reduce__(self) -> tuple:
113
114
# TODO: To be implemented. Required for Ray processing.
114
115
raise NotImplementedError
115
116
117
async def _get_vector_size_info(self) -> VectorSize:
    """
    Return the embedder's vector size information, querying the embedder at most once.

    The first call awaits ``self._embedder.get_vector_size()`` and caches the
    result on the instance; later calls return the cached object.

    Returns:
        The vector size information reported by the embedder.
    """
    cached_info = self._vector_size_info
    if cached_info is not None:
        return cached_info

    fetched_info = await self._embedder.get_vector_size()
    self._vector_size_info = fetched_info
    # Keep _vector_size populated for code that reads it directly, but never
    # override a size that was passed explicitly to the constructor.
    if self._vector_size is None:
        self._vector_size = fetched_info.size
    return fetched_info
130
+
131
async def _get_vector_size(self) -> int:
    """
    Resolve the vector size to use for this store.

    A size already stored on the instance takes precedence; otherwise the size
    is taken from the information returned by ``_get_vector_size_info``.

    Returns:
        The vector size as an integer.
    """
    known_size = self._vector_size
    if known_size is None:
        size_info = await self._get_vector_size_info()
        return size_info.size
    return known_size
142
+
116
143
def _vector_to_string (self , vector : list [float ] | SparseVector ) -> str :
117
144
"""
118
145
Converts a vector to a string representation.
@@ -124,8 +151,13 @@ def _vector_to_string(self, vector: list[float] | SparseVector) -> str:
124
151
str: The string representation of the vector.
125
152
"""
126
153
if isinstance (vector , SparseVector ):
154
+ # For sparse vectors, we need the vector size to be available
155
+ # This will be resolved when this method is called from async context
156
+ vector_size = self ._vector_size
157
+ if vector_size is None :
158
+ raise RuntimeError ("Vector size must be determined before converting sparse vectors to string" )
127
159
points_str = "," .join (f"{ i } :{ v } " for i , v in zip (vector .indices , vector .values , strict = False ))
128
- return f"{{{ points_str } }}/{ self . _vector_size } "
160
+ return f"{{{ points_str } }}/{ vector_size } "
129
161
return json .dumps (vector )
130
162
131
163
@staticmethod
@@ -234,23 +266,25 @@ async def create_table(self) -> None:
234
266
"""
235
267
Create a pgVector table with an HNSW index for given similarity.
236
268
"""
269
+ vector_size = await self ._get_vector_size ()
237
270
with trace (
238
271
table_name = self ._table_name ,
239
272
distance_method = self ._distance_method ,
240
- vector_size = self . _vector_size ,
273
+ vector_size = vector_size ,
241
274
hnsw_index_parameters = self ._hnsw_params ,
242
275
):
243
276
distance = DISTANCE_OPS [self ._distance_method ].function_name
244
277
create_vector_extension = "CREATE EXTENSION IF NOT EXISTS vector;"
245
278
# _table_name has been validated in the class constructor, and it is a valid table name.
246
- # _vector_size has been validated in the class constructor, and it is a valid vector size.
279
+ # vector_size has been validated in the class constructor or obtained from embedder,
280
+ # and it is a valid vector size.
247
281
248
282
is_sparse = isinstance (self ._embedder , SparseEmbedder )
249
283
vector_func = "VECTOR" if not is_sparse else "SPARSEVEC"
250
284
251
285
create_table_query = f"""
252
286
CREATE TABLE { self ._table_name }
253
- (id UUID, text TEXT, image_bytes BYTEA, vector { vector_func } ({ self . _vector_size } ), metadata JSONB);
287
+ (id UUID, text TEXT, image_bytes BYTEA, vector { vector_func } ({ vector_size } ), metadata JSONB);
254
288
"""
255
289
# _hnsw_params has been validated in the class constructor, and it is valid dict[str,int].
256
290
create_index_query = f"""
@@ -283,6 +317,10 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
283
317
"""
284
318
if not entries :
285
319
return
320
+
321
+ # Ensure vector size is determined before processing
322
+ vector_size = await self ._get_vector_size ()
323
+
286
324
# _table_name has been validated in the class constructor, and it is a valid table name.
287
325
insert_query = f"""
288
326
INSERT INTO { self ._table_name } (id, text, image_bytes, vector, metadata)
@@ -291,7 +329,7 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
291
329
with trace (
292
330
table_name = self ._table_name ,
293
331
entries = entries ,
294
- vector_size = self . _vector_size ,
332
+ vector_size = vector_size ,
295
333
embedder = repr (self ._embedder ),
296
334
embedding_type = self ._embedding_type ,
297
335
):
@@ -359,11 +397,14 @@ async def retrieve(
359
397
"""
360
398
merged_options = (self .default_options | options ) if options else self .default_options
361
399
400
+ # Ensure vector size is determined before processing
401
+ vector_size = await self ._get_vector_size ()
402
+
362
403
with trace (
363
404
text = text ,
364
405
options = merged_options .dict (),
365
406
table_name = self ._table_name ,
366
- vector_size = self . _vector_size ,
407
+ vector_size = vector_size ,
367
408
distance_method = self ._distance_method ,
368
409
embedder = repr (self ._embedder ),
369
410
embedding_type = self ._embedding_type ,
0 commit comments