Skip to content

Commit 1d6fcf9

Browse files
authored
feat: resolve pgvector vector_size automatically (#588)
1 parent 0c36d1b commit 1d6fcf9

File tree

5 files changed

+97
-37
lines changed

5 files changed

+97
-37
lines changed

docs/how-to/vector_stores/use_pgVector_store.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ async with asyncpg.create_pool(dsn=DB) as pool:
6262

6363
The connection pool created with asyncpg.create_pool will be used to initialize an instance of PgVectorStore.
6464

65-
6665
```python
6766
import asyncpg
6867
from ragbits.core.vector_stores.pgvector import PgVectorStore
@@ -71,12 +70,13 @@ async def main() -> None:
7170
DB = "postgresql://ragbits_user:ragbits_password@localhost:5432/ragbits_db"
7271
async with asyncpg.create_pool(dsn=DB) as pool:
7372
embedder = LiteLLMEmbedder(model="text-embedding-3-small")
74-
vector_store = PgVectorStore(embedder=embedder, client=pool, table_name="test_table", vector_size=1536)
73+
vector_store = PgVectorStore(embedder=embedder, client=pool, table_name="test_table")
7574
```
7675

76+
7777
!!! note
78-
Ensure that the vector size is correctly configured when initializing PgVectorStore,
79-
as it must match the expected dimensions of the stored embeddings.
78+
PgVectorStore will automatically determine the vector dimensions from the embedder.
79+
If you prefer explicit control or need to override the automatic detection, you can pass the `vector_size` parameter to the PgVectorStore initializer.
8080

8181
## pgVectorStore in Ragbits
8282
Example:
@@ -91,19 +91,19 @@ async def main() -> None:
9191
DB = "postgresql://ragbits_user:ragbits_password@localhost:5432/ragbits_db"
9292
async with asyncpg.create_pool(dsn=DB) as pool:
9393
embedder = LiteLLMEmbedder(model="text-embedding-3-small")
94-
vector_store = PgVectorStore(embedder=embedder, client=pool, table_name="test_table", vector_size=3)
95-
data = [VectorStoreEntry(id="test_id_1", key="test_key_1", vector=[0.1, 0.2, 0.3],
94+
vector_store = PgVectorStore(embedder=embedder, client=pool, table_name="test_table")
95+
data = [VectorStoreEntry(id="test_id_1", text="test text 1",
9696
metadata={"key1": "value1", "content": "test 1"}),
97-
VectorStoreEntry(id="test_id_2", key="test_key_2", vector=[0.4, 0.5, 0.6],
97+
VectorStoreEntry(id="test_id_2", text="test text 2",
9898
metadata={"key2": "value2", "content": "test 2"})]
9999

100100
await vector_store.store(data)
101101
all_entries = await vector_store.list()
102102
print("All entries ", all_entries)
103103
list_result = await vector_store.list({"content": "test 1"})
104104
print("Entries with {content: test 1}", list_result)
105-
retrieve_result = await vector_store.retrieve(vector=[0.39, 0.55, 0.6])
106-
print("Entries similar to [0.17, 0.23, 0.314] ", retrieve_result)
105+
retrieve_result = await vector_store.retrieve("similar test query")
106+
print("Entries similar to query", retrieve_result)
107107
await vector_store.remove(["test_id_1", "test_id_2"])
108108
after_remove = await vector_store.list()
109109
print("Entries after remove ", after_remove)

examples/document-search/pgvector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async def main() -> None:
8484
embedder = LiteLLMEmbedder(
8585
model_name="text-embedding-3-small",
8686
)
87-
vector_store = PgVectorStore(embedder=embedder, client=pool, table_name="example", vector_size=1536)
87+
vector_store = PgVectorStore(embedder=embedder, client=pool, table_name="example")
8888
document_search = DocumentSearch(
8989
vector_store=vector_store,
9090
)

packages/ragbits-core/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
## Unreleased
44

5-
- feat: add get_vector_size method to all Embedders (#587)
5+
- Resolve vector_size by PgVectorStore automatically (#588)
6+
- Add get_vector_size method to all Embedders (#587)
67

78
## 0.19.1 (2025-05-27)
89

packages/ragbits-core/src/ragbits/core/vector_stores/pgvector.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic.json import pydantic_encoder
88

99
from ragbits.core.audit.traces import trace
10-
from ragbits.core.embeddings.base import Embedder, SparseVector
10+
from ragbits.core.embeddings.base import Embedder, SparseVector, VectorSize
1111
from ragbits.core.embeddings.sparse.base import SparseEmbedder
1212
from ragbits.core.vector_stores.base import (
1313
EmbeddingType,
@@ -53,8 +53,8 @@ def __init__(
5353
self,
5454
client: asyncpg.Pool,
5555
table_name: str,
56-
vector_size: int,
5756
embedder: Embedder,
57+
vector_size: int | None = None,
5858
embedding_type: EmbeddingType = EmbeddingType.TEXT,
5959
distance_method: str | None = None,
6060
hnsw_params: dict | None = None,
@@ -66,8 +66,8 @@ def __init__(
6666
Args:
6767
client: The pgVector database connection pool.
6868
table_name: The name of the table.
69-
vector_size: The size of the vectors.
7069
embedder: The embedder to use for converting entries to vectors.
70+
vector_size: The size of the vectors. If None, will be determined automatically from the embedder.
7171
embedding_type: Which part of the entry to embed, either text or image. The other part will be ignored.
7272
distance_method: The distance method to use, default is "cosine" for dense vectors
7373
and "sparsevec_l2" for sparse vectors.
@@ -84,7 +84,7 @@ def __init__(
8484

8585
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", table_name):
8686
raise ValueError(f"Invalid table name: {table_name}")
87-
if not isinstance(vector_size, int) or vector_size <= 0:
87+
if vector_size is not None and (not isinstance(vector_size, int) or vector_size <= 0):
8888
raise ValueError("Vector size must be a positive integer.")
8989

9090
if hnsw_params is None:
@@ -103,6 +103,7 @@ def __init__(
103103
self._client = client
104104
self._table_name = table_name
105105
self._vector_size = vector_size
106+
self._vector_size_info: VectorSize | None = None
106107
self._distance_method = distance_method
107108
self._hnsw_params = hnsw_params
108109

@@ -113,6 +114,32 @@ def __reduce__(self) -> tuple:
113114
# TODO: To be implemented. Required for Ray processing.
114115
raise NotImplementedError
115116

117+
async def _get_vector_size_info(self) -> VectorSize:
118+
"""
119+
Get vector size information from the embedder if not already cached.
120+
121+
Returns:
122+
VectorSize information including size and sparsity.
123+
"""
124+
if self._vector_size_info is None:
125+
self._vector_size_info = await self._embedder.get_vector_size()
126+
# Update _vector_size for backward compatibility if it wasn't provided
127+
if self._vector_size is None:
128+
self._vector_size = self._vector_size_info.size
129+
return self._vector_size_info
130+
131+
async def _get_vector_size(self) -> int:
132+
"""
133+
Get the vector size, either from the constructor parameter or from the embedder.
134+
135+
Returns:
136+
The vector size as an integer.
137+
"""
138+
if self._vector_size is not None:
139+
return self._vector_size
140+
vector_size_info = await self._get_vector_size_info()
141+
return vector_size_info.size
142+
116143
def _vector_to_string(self, vector: list[float] | SparseVector) -> str:
117144
"""
118145
Converts a vector to a string representation.
@@ -124,8 +151,13 @@ def _vector_to_string(self, vector: list[float] | SparseVector) -> str:
124151
str: The string representation of the vector.
125152
"""
126153
if isinstance(vector, SparseVector):
154+
# For sparse vectors, we need the vector size to be available
155+
# This will be resolved when this method is called from async context
156+
vector_size = self._vector_size
157+
if vector_size is None:
158+
raise RuntimeError("Vector size must be determined before converting sparse vectors to string")
127159
points_str = ",".join(f"{i}:{v}" for i, v in zip(vector.indices, vector.values, strict=False))
128-
return f"{{{points_str}}}/{self._vector_size}"
160+
return f"{{{points_str}}}/{vector_size}"
129161
return json.dumps(vector)
130162

131163
@staticmethod
@@ -234,23 +266,25 @@ async def create_table(self) -> None:
234266
"""
235267
Create a pgVector table with an HNSW index for given similarity.
236268
"""
269+
vector_size = await self._get_vector_size()
237270
with trace(
238271
table_name=self._table_name,
239272
distance_method=self._distance_method,
240-
vector_size=self._vector_size,
273+
vector_size=vector_size,
241274
hnsw_index_parameters=self._hnsw_params,
242275
):
243276
distance = DISTANCE_OPS[self._distance_method].function_name
244277
create_vector_extension = "CREATE EXTENSION IF NOT EXISTS vector;"
245278
# _table_name has been validated in the class constructor, and it is a valid table name.
246-
# _vector_size has been validated in the class constructor, and it is a valid vector size.
279+
# vector_size has been validated in the class constructor or obtained from embedder,
280+
# and it is a valid vector size.
247281

248282
is_sparse = isinstance(self._embedder, SparseEmbedder)
249283
vector_func = "VECTOR" if not is_sparse else "SPARSEVEC"
250284

251285
create_table_query = f"""
252286
CREATE TABLE {self._table_name}
253-
(id UUID, text TEXT, image_bytes BYTEA, vector {vector_func}({self._vector_size}), metadata JSONB);
287+
(id UUID, text TEXT, image_bytes BYTEA, vector {vector_func}({vector_size}), metadata JSONB);
254288
"""
255289
# _hnsw_params has been validated in the class constructor, and it is valid dict[str,int].
256290
create_index_query = f"""
@@ -283,6 +317,10 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
283317
"""
284318
if not entries:
285319
return
320+
321+
# Ensure vector size is determined before processing
322+
vector_size = await self._get_vector_size()
323+
286324
# _table_name has been validated in the class constructor, and it is a valid table name.
287325
insert_query = f"""
288326
INSERT INTO {self._table_name} (id, text, image_bytes, vector, metadata)
@@ -291,7 +329,7 @@ async def store(self, entries: list[VectorStoreEntry]) -> None:
291329
with trace(
292330
table_name=self._table_name,
293331
entries=entries,
294-
vector_size=self._vector_size,
332+
vector_size=vector_size,
295333
embedder=repr(self._embedder),
296334
embedding_type=self._embedding_type,
297335
):
@@ -359,11 +397,14 @@ async def retrieve(
359397
"""
360398
merged_options = (self.default_options | options) if options else self.default_options
361399

400+
# Ensure vector size is determined before processing
401+
vector_size = await self._get_vector_size()
402+
362403
with trace(
363404
text=text,
364405
options=merged_options.dict(),
365406
table_name=self._table_name,
366-
vector_size=self._vector_size,
407+
vector_size=vector_size,
367408
distance_method=self._distance_method,
368409
embedder=repr(self._embedder),
369410
embedding_type=self._embedding_type,

packages/ragbits-core/tests/unit/vector_stores/test_pgvector.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -358,21 +358,39 @@ async def test_list_with_nested_where_clause(
358358
async def test_retrieve_with_where_clause_and_score_threshold(
359359
mock_pgvector_store: PgVectorStore, mock_db_pool: tuple[MagicMock, AsyncMock]
360360
) -> None:
361-
data = DATA_JSON_EXAMPLE
362-
query = "SQL RETRIEVE QUERY"
363361
_, mock_conn = mock_db_pool
364-
where_clause: dict[str, str | int | float | bool | dict[Any, Any]] = {"key1": "value1"}
362+
mock_conn.fetch.return_value = []
363+
where = cast(WhereQuery, {"id": "test_id"})
364+
await mock_pgvector_store.retrieve(text="test", options=VectorStoreOptions(score_threshold=0.1, where=where))
365+
mock_conn.fetch.assert_called_once()
366+
calls = mock_conn.fetch.mock_calls
367+
assert calls[0].args[0].startswith("SELECT *, vector <=> $1 as distance, 1 - (vector <=> $1) as score FROM")
368+
assert calls[0].args[0].endswith("WHERE score >= $2 AND metadata @> $3 ORDER BY distance LIMIT $4;")
365369

366-
with patch.object(mock_pgvector_store, "_create_retrieve_query") as mock_create_retrieve_query:
367-
mock_conn.fetch = AsyncMock(return_value=data)
368-
mock_create_retrieve_query.return_value = (query, [["[0.1, 0.2, 0.3]", 0.1, 1]])
369-
results = await mock_pgvector_store.retrieve(
370-
text="some_text", options=VectorStoreOptions(where=where_clause, score_threshold=0.5)
371-
)
372-
mock_create_retrieve_query.assert_called_once()
373-
mock_conn.fetch.assert_called_once()
374-
assert len(results) == 2
375-
assert isinstance(results[0], VectorStoreResult)
376-
assert isinstance(results[1], VectorStoreResult)
377-
assert results[0].entry.id == UUID("8c7d6b27-4ef1-537c-ad7c-676edb8bc8a8")
378-
assert results[1].entry.id == UUID("9c7d6b27-4ef1-537c-ad7c-676edb8bc8a8")
370+
371+
@pytest.mark.asyncio
372+
async def test_auto_vector_size_determination(mock_db_pool: tuple[MagicMock, AsyncMock]) -> None:
373+
"""Test that PgVectorStore can determine vector size automatically from embedder."""
374+
mock_pool, _mock_conn = mock_db_pool
375+
mock_embedder = AsyncMock()
376+
377+
# Mock the get_vector_size method to return a VectorSize with size 5
378+
from ragbits.core.embeddings.base import VectorSize
379+
380+
mock_embedder.get_vector_size.return_value = VectorSize(size=5, is_sparse=False)
381+
382+
# Create PgVectorStore without providing vector_size
383+
store = PgVectorStore(client=mock_pool, table_name=TEST_TABLE_NAME, embedder=mock_embedder)
384+
385+
# The vector size should be None initially
386+
assert store._vector_size is None
387+
388+
# When we call _get_vector_size(), it should determine the size from embedder
389+
vector_size = await store._get_vector_size()
390+
assert vector_size == 5
391+
392+
# Now _vector_size should be cached
393+
assert store._vector_size == 5
394+
395+
# Verify the embedder's get_vector_size was called
396+
mock_embedder.get_vector_size.assert_called_once()

0 commit comments

Comments
 (0)