Skip to content

Commit 31eacbd

Browse files
committed
WIP on threadpool impl of query_namespaces
1 parent 247a329 commit 31eacbd

File tree

4 files changed

+102
-4
lines changed

4 files changed

+102
-4
lines changed

pinecone/data/index.py

+100-2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from pinecone.core.openapi.data.api.data_plane_api import DataPlaneApi
2828
from ..utils import setup_openapi_client, parse_non_empty_args
2929
from .vector_factory import VectorFactory
30+
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
31+
from multiprocessing.pool import ApplyResult
3032

3133
__all__ = [
3234
"Index",
@@ -361,7 +363,7 @@ def query(
361363
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
362364
] = None,
363365
**kwargs,
364-
) -> QueryResponse:
366+
) -> Union[QueryResponse, ApplyResult[QueryResponse]]:
365367
"""
366368
The Query operation searches a namespace, using a query vector.
367369
It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
@@ -403,6 +405,39 @@ def query(
403405
and namespace name.
404406
"""
405407

408+
response = self._query(
409+
*args,
410+
top_k=top_k,
411+
vector=vector,
412+
id=id,
413+
namespace=namespace,
414+
filter=filter,
415+
include_values=include_values,
416+
include_metadata=include_metadata,
417+
sparse_vector=sparse_vector,
418+
**kwargs,
419+
)
420+
421+
if kwargs.get("async_req", False):
422+
return response
423+
else:
424+
return parse_query_response(response)
425+
426+
def _query(
427+
self,
428+
*args,
429+
top_k: int,
430+
vector: Optional[List[float]] = None,
431+
id: Optional[str] = None,
432+
namespace: Optional[str] = None,
433+
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
434+
include_values: Optional[bool] = None,
435+
include_metadata: Optional[bool] = None,
436+
sparse_vector: Optional[
437+
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
438+
] = None,
439+
**kwargs,
440+
) -> QueryResponse:
406441
if len(args) > 0:
407442
raise ValueError(
408443
"The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
@@ -435,7 +470,70 @@ def query(
435470
),
436471
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
437472
)
438-
return parse_query_response(response)
473+
return response
474+
475+
def query_namespaces(
    self,
    vector: List[float],
    namespaces: List[str],
    top_k: Optional[int] = None,
    filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
    include_values: Optional[bool] = None,
    include_metadata: Optional[bool] = None,
    sparse_vector: Optional[
        Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
    ] = None,
    show_progress: Optional[bool] = True,
    **kwargs,
) -> QueryNamespacesResults:
    """Run the same query against several namespaces and merge the results.

    One sub-query per distinct namespace is dispatched through
    ``self.query(..., async_req=True)``. Each returned handle is then
    resolved with ``.get()`` and fed to a ``QueryResultsAggregator``, whose
    ``get_results()`` value is returned.

    Args:
        vector: The query vector. Must not be empty.
        namespaces: Namespaces to search. Must not be empty. Duplicates are
            dropped, keeping the first occurrence of each name.
        top_k: Number of results wanted overall. Defaults to 10 when None.
        filter: Forwarded unchanged to every sub-query.
        include_values: Forwarded unchanged to every sub-query.
        include_metadata: Forwarded unchanged to every sub-query.
        sparse_vector: Forwarded unchanged to every sub-query.
        show_progress: Accepted for interface compatibility; this
            implementation does not read it.
        **kwargs: Extra keyword arguments forwarded to every sub-query. An
            ``async_req`` entry is ignored because sub-queries are always
            dispatched with ``async_req=True``.

    Returns:
        The aggregated results produced by ``QueryResultsAggregator``.

    Raises:
        ValueError: If ``namespaces`` or ``vector`` is empty.
    """
    if len(namespaces) == 0:
        raise ValueError("At least one namespace must be specified")
    if len(vector) == 0:
        raise ValueError("Query vector must not be empty")

    # The caller may only want the top_k=1 result across all queries, but
    # each sub-query is asked for at least 2 results so that the aggregator
    # can combine them correctly. The aggregator itself is configured with
    # the caller's overall top_k and trims the merged list accordingly.
    overall_topk = top_k if top_k is not None else 10
    aggregator = QueryResultsAggregator(top_k=overall_topk)
    subquery_topk = max(overall_topk, 2)

    # async_req is always forced to True below; forwarding a caller-supplied
    # copy as well would raise "got multiple values for keyword argument".
    subquery_kwargs = {k: v for k, v in kwargs.items() if k != "async_req"}

    # De-duplicate while preserving the caller's order so that dispatch and
    # aggregation order are deterministic (a plain set() is unordered).
    target_namespaces = list(dict.fromkeys(namespaces))

    async_results = [
        self.query(
            vector=vector,
            namespace=ns,
            top_k=subquery_topk,
            filter=filter,
            include_values=include_values,
            include_metadata=include_metadata,
            sparse_vector=sparse_vector,
            async_req=True,
            **subquery_kwargs,
        )
        for ns in target_namespaces
    ]

    # All sub-queries have been dispatched; resolve them in dispatch order.
    for result in async_results:
        aggregator.add_results(result.get())

    return aggregator.get_results()
439537

440538
@validate_and_convert_errors
441539
def update(

pinecone/grpc/index_grpc_asyncio.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
parse_sparse_values_arg,
4343
)
4444
from .vector_factory_grpc import VectorFactoryGRPC
45-
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
45+
from ..data.query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
4646

4747

4848
class GRPCIndexAsyncio(GRPCIndexBase):

tests/unit_grpc/test_query_results_aggregator.py tests/unit/test_query_results_aggregator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pinecone.grpc.query_results_aggregator import (
1+
from pinecone.data.query_results_aggregator import (
22
QueryResultsAggregator,
33
QueryResultsAggregatorInvalidTopKError,
44
QueryResultsAggregregatorNotEnoughResultsError,

0 commit comments

Comments
 (0)