Commit 8116e7b

Metadata & Applying Filters (#103)
* feat: add ability to add metadata field
* feat: add applying filters on queries for qdrant
* feat: applying filters for other providers
* fix: pydantic error
* chore: fix linting
1 parent: 32bde47

10 files changed: +45 additions, −10 deletions


models/file.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ def suffix(self) -> str:
 class File(BaseModel):
     url: str
     name: str | None = None
+    metadata: dict | None = None

     @property
     def type(self) -> FileType | None:
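
A minimal sketch of how the new field might be populated at ingestion time; the URL, name, and metadata values below are purely illustrative:

from models.file import File

# Illustrative only: arbitrary key/value metadata attached to a file,
# later merged into every chunk generated from that file.
file = File(
    url="https://example.com/handbook.pdf",
    name="handbook.pdf",
    metadata={"department": "hr", "year": 2024},
)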

models/query.py

Lines changed: 6 additions & 2 deletions
@@ -1,10 +1,13 @@
-from typing import List, Optional
-
 from pydantic import BaseModel
+from typing import List, Optional, Union

 from models.document import BaseDocumentChunk
 from models.ingest import EncoderConfig
 from models.vector_database import VectorDatabase
+from qdrant_client.http.models import Filter as QdrantFilter
+
+
+Filter = Union[QdrantFilter, dict]


 class RequestPayload(BaseModel):
@@ -15,6 +18,7 @@ class RequestPayload(BaseModel):
     session_id: Optional[str] = None
     interpreter_mode: Optional[bool] = False
     exclude_fields: List[str] = None
+    filter: Optional[Filter] = None


 class ResponseData(BaseModel):
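
Filter is a union, so a request can carry either a typed Qdrant filter or a provider-native dict (the dict dialects are sketched under the individual providers below). A minimal sketch of the typed form, with an illustrative key and value:

from qdrant_client.http.models import FieldCondition, Filter as QdrantFilter, MatchValue

# Illustrative typed filter; "department" is a hypothetical metadata key.
typed_filter = QdrantFilter(
    must=[FieldCondition(key="department", match=MatchValue(value="hr"))]
)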

service/embedding.py

Lines changed: 5 additions & 1 deletion
@@ -156,6 +156,7 @@ async def generate_chunks(
     ) -> List[BaseDocumentChunk]:
         doc_chunks = []
         for file in tqdm(self.files, desc="Generating chunks"):
+            file_metadata = file.metadata or {}
             logger.info(f"Splitting method: {config.splitter.name}")
             try:
                 chunks = []
@@ -168,7 +169,10 @@ async def generate_chunks(
                     chunk_data = {
                         "content": element.get("text"),
                         "metadata": self._sanitize_metadata(
-                            element.get("metadata")
+                            {
+                                **file_metadata,
+                                **element.get("metadata"),
+                            }
                         ),
                     }
                     chunks.append(chunk_data)
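
The double unpack merges the two metadata dicts before sanitizing; element-level keys win on collision because they are unpacked last. Illustrative values:

file_metadata = {"department": "hr", "source": "upload"}
element_metadata = {"source": "handbook.pdf", "page_number": 3}

merged = {**file_metadata, **element_metadata}
# -> {"department": "hr", "source": "handbook.pdf", "page_number": 3}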

service/router.py

Lines changed: 3 additions & 1 deletion
@@ -40,7 +40,9 @@ def create_route_layer() -> RouteLayer:
 async def get_documents(
     *, vector_service: BaseVectorDatabase, payload: RequestPayload
 ) -> list[BaseDocumentChunk]:
-    chunks = await vector_service.query(input=payload.input, top_k=5)
+    chunks = await vector_service.query(
+        input=payload.input, filter=payload.filter, top_k=5
+    )
     # filter out documents with empty content
     chunks = [chunk for chunk in chunks if chunk.content.strip()]
    if not len(chunks):
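
End to end, a query request can now carry the filter alongside the input. A hypothetical payload sketch showing only the fields visible in this commit (other required RequestPayload fields omitted; the dict shape is provider-dependent):

payload = {
    "input": "What is the refund policy?",
    "filter": {"department": {"$eq": "hr"}},
}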

vectordbs/astra.py

Lines changed: 3 additions & 1 deletion
@@ -5,6 +5,7 @@
 from tqdm import tqdm

 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase


@@ -54,12 +55,13 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
         for i in range(0, len(documents), 5):
             self.collection.insert_many(documents=documents[i : i + 5])

-    async def query(self, input: str, top_k: int = 4) -> List:
+    async def query(self, input: str, filter: Filter = None, top_k: int = 4) -> List:
         vectors = await self._generate_vectors(input=input)
         results = self.collection.vector_find(
             vector=vectors[0],
             limit=top_k,
             fields={"text", "page_number", "source", "document_id"},
+            filter=filter,
         )
         return [
             BaseDocumentChunk(

vectordbs/base.py

Lines changed: 4 additions & 1 deletion
@@ -7,6 +7,7 @@

 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from utils.logger import logger


@@ -24,7 +25,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]):
         pass

     @abstractmethod
-    async def query(self, input: str, top_k: int = 25) -> List[BaseDocumentChunk]:
+    async def query(
+        self, input: str, filter: Filter, top_k: int = 25
+    ) -> List[BaseDocumentChunk]:
         pass

     @abstractmethod

vectordbs/pgvector.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,7 @@
 from qdrant_client.http import models as rest
 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase

 MAX_QUERY_TOP_K = 5
@@ -58,14 +59,17 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
         self.collection.upsert(records)
         self.collection.create_index()

-    async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
+    async def query(
+        self, input: str, filter: Filter = None, top_k: int = MAX_QUERY_TOP_K
+    ) -> List:
         vectors = await self._generate_vectors(input=input)

         results = self.collection.query(
             data=vectors[0],
             limit=top_k,
             include_metadata=True,
             include_value=False,
+            filters=filter.model_dump() if filter else {},
         )

         chunks = []
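
Because the pgvector client is handed a plain dict of filters, the typed filter form is serialized with model_dump() first; a one-line sketch, reusing the typed_filter from the models/query.py example above:

# typed_filter as sketched under models/query.py; None falls back to an empty dict.
filters = typed_filter.model_dump() if typed_filter else {}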

vectordbs/pinecone.py

Lines changed: 7 additions & 1 deletion
@@ -6,6 +6,7 @@

 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from utils.logger import logger
 from vectordbs.base import BaseVectorDatabase

@@ -52,7 +53,11 @@ async def upsert(self, chunks: List[BaseDocumentChunk], batch_size: int = 100):
             raise

     async def query(
-        self, input: str, top_k: int = 25, include_metadata: bool = True
+        self,
+        input: str,
+        filter: Filter = None,
+        top_k: int = 25,
+        include_metadata: bool = True,
     ) -> list[BaseDocumentChunk]:
         if self.index is None:
             raise ValueError(f"Pinecone index {self.index_name} is not initialized.")
@@ -61,6 +66,7 @@ async def query(
             vector=query_vectors[0],
             top_k=top_k,
             include_metadata=include_metadata,
+            filter=filter,
         )
         chunks = []
         if results.get("matches"):
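
For the Pinecone path, the dict form of the union is passed straight through, so it would follow Pinecone's metadata filter syntax; the key names below are illustrative:

# Illustrative Pinecone-style metadata filter using $eq / $in operators.
pinecone_filter = {
    "source": {"$eq": "handbook.pdf"},
    "department": {"$in": ["hr", "legal"]},
}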

vectordbs/qdrant.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,7 @@

 from models.delete import DeleteResponse
 from models.document import BaseDocumentChunk
+from models.query import Filter
 from vectordbs.base import BaseVectorDatabase

 MAX_QUERY_TOP_K = 5
@@ -69,11 +70,14 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:

         self.client.upsert(collection_name=self.index_name, wait=True, points=points)

-    async def query(self, input: str, top_k: int = MAX_QUERY_TOP_K) -> List:
+    async def query(
+        self, input: str, filter: Filter, top_k: int = MAX_QUERY_TOP_K
+    ) -> List:
         vectors = await self._generate_vectors(input=input)
         search_result = self.client.search(
             collection_name=self.index_name,
             query_vector=("content", vectors[0]),
+            query_filter=filter,
             limit=top_k,
             with_payload=True,
         )
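
Qdrant's search accepts the typed Filter directly as query_filter, so richer conditions can be expressed; a small sketch with hypothetical payload keys:

from qdrant_client.http.models import FieldCondition, Filter, MatchValue

# Hypothetical payload key; "should" means at least one condition must match (OR semantics).
query_filter = Filter(
    should=[
        FieldCondition(key="source", match=MatchValue(value="handbook.pdf")),
        FieldCondition(key="source", match=MatchValue(value="faq.md")),
    ]
)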

vectordbs/weaviate.py

Lines changed: 6 additions & 1 deletion
@@ -10,6 +10,8 @@
 from utils.logger import logger
 from vectordbs.base import BaseVectorDatabase

+from models.query import Filter
+

 class WeaviateService(BaseVectorDatabase):
     def __init__(
@@ -72,7 +74,9 @@ async def upsert(self, chunks: List[BaseDocumentChunk]) -> None:
                 batch.add_data_object(**vector_data)
             batch.flush()

-    async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]:
+    async def query(
+        self, input: str, filter: Filter = {}, top_k: int = 25
+    ) -> list[BaseDocumentChunk]:
         vectors = await self._generate_vectors(input=input)
         vector = {"vector": vectors[0]}

@@ -84,6 +88,7 @@ async def query(self, input: str, top_k: int = 25) -> list[BaseDocumentChunk]:
             )
             .with_near_vector(vector)
             .with_limit(top_k)
+            .with_where(filter)
             .do()
         )
         if "data" not in response:
