|
27 | 27 | from pinecone.core.openapi.data.api.data_plane_api import DataPlaneApi
|
28 | 28 | from ..utils import setup_openapi_client, parse_non_empty_args
|
29 | 29 | from .vector_factory import VectorFactory
|
| 30 | +from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults |
| 31 | +from multiprocessing.pool import ApplyResult |
30 | 32 |
|
31 | 33 | __all__ = [
|
32 | 34 | "Index",
|
@@ -361,7 +363,7 @@ def query(
|
361 | 363 | Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
|
362 | 364 | ] = None,
|
363 | 365 | **kwargs,
|
364 |
| - ) -> QueryResponse: |
| 366 | + ) -> Union[QueryResponse, ApplyResult[QueryResponse]]: |
365 | 367 | """
|
366 | 368 | The Query operation searches a namespace, using a query vector.
|
367 | 369 | It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
|
@@ -403,6 +405,39 @@ def query(
|
403 | 405 | and namespace name.
|
404 | 406 | """
|
405 | 407 |
|
| 408 | + response = self._query( |
| 409 | + *args, |
| 410 | + top_k=top_k, |
| 411 | + vector=vector, |
| 412 | + id=id, |
| 413 | + namespace=namespace, |
| 414 | + filter=filter, |
| 415 | + include_values=include_values, |
| 416 | + include_metadata=include_metadata, |
| 417 | + sparse_vector=sparse_vector, |
| 418 | + **kwargs, |
| 419 | + ) |
| 420 | + |
| 421 | + if kwargs.get("async_req", False): |
| 422 | + return response |
| 423 | + else: |
| 424 | + return parse_query_response(response) |
| 425 | + |
| 426 | + def _query( |
| 427 | + self, |
| 428 | + *args, |
| 429 | + top_k: int, |
| 430 | + vector: Optional[List[float]] = None, |
| 431 | + id: Optional[str] = None, |
| 432 | + namespace: Optional[str] = None, |
| 433 | + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, |
| 434 | + include_values: Optional[bool] = None, |
| 435 | + include_metadata: Optional[bool] = None, |
| 436 | + sparse_vector: Optional[ |
| 437 | + Union[SparseValues, Dict[str, Union[List[float], List[int]]]] |
| 438 | + ] = None, |
| 439 | + **kwargs, |
| 440 | + ) -> QueryResponse: |
406 | 441 | if len(args) > 0:
|
407 | 442 | raise ValueError(
|
408 | 443 | "The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
|
@@ -435,7 +470,70 @@ def query(
|
435 | 470 | ),
|
436 | 471 | **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
|
437 | 472 | )
|
438 |
| - return parse_query_response(response) |
| 473 | + return response |
| 474 | + |
def query_namespaces(
    self,
    vector: List[float],
    namespaces: List[str],
    top_k: Optional[int] = None,
    filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
    include_values: Optional[bool] = None,
    include_metadata: Optional[bool] = None,
    sparse_vector: Optional[
        Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
    ] = None,
    show_progress: Optional[bool] = True,
    **kwargs,
) -> QueryNamespacesResults:
    """Run the same query against several namespaces and merge the results.

    One `query()` call is issued per unique namespace with `async_req=True`;
    each pending result is then collected with `.get()` and fed into a
    `QueryResultsAggregator`, whose combined output is returned.

    Args:
        vector: The query vector. Must not be empty.
        namespaces: Namespaces to search. Must contain at least one entry;
            duplicates are queried only once, in first-seen order.
        top_k: Number of results wanted across all namespaces combined.
            Defaults to 10 when omitted.
        filter: Metadata filter forwarded unchanged to every sub-query.
        include_values: Forwarded unchanged to every sub-query.
        include_metadata: Forwarded unchanged to every sub-query.
        sparse_vector: Sparse query vector forwarded unchanged to every
            sub-query.
        show_progress: Accepted for interface compatibility; currently has
            no effect because no progress reporting is implemented here.
        **kwargs: Extra keyword arguments forwarded to every `query()` call.
            Any `async_req` entry is ignored, since the sub-queries are
            always dispatched asynchronously and this method returns only
            after all of them have been collected.

    Returns:
        The aggregated results produced by `QueryResultsAggregator.get_results()`.

    Raises:
        ValueError: If `namespaces` or `vector` is empty.
    """
    # `len(...) == 0` rather than truthiness: it also works for array-like
    # vectors whose truth value is ambiguous.
    if len(namespaces) == 0:
        raise ValueError("At least one namespace must be specified")
    if len(vector) == 0:
        raise ValueError("Query vector must not be empty")

    # `async_req=True` is passed explicitly below; drop any caller-supplied
    # value so the call does not fail with a duplicate-keyword TypeError.
    kwargs.pop("async_req", None)

    overall_topk = top_k if top_k is not None else 10
    aggregator = QueryResultsAggregator(top_k=overall_topk)

    # The caller may only want the top_k=1 result across all queries, but
    # at least 2 results are requested from each sub-query so that they can
    # be aggregated correctly; the aggregator then trims to `overall_topk`.
    subquery_topk = max(overall_topk, 2)

    # Deduplicate while keeping first-seen order, so that dispatch and
    # aggregation order are deterministic from run to run (a plain `set`
    # iterates in an order that depends on string hashing).
    target_namespaces = list(dict.fromkeys(namespaces))

    # Dispatch every sub-query before collecting any of them.
    async_results = [
        self.query(
            vector=vector,
            namespace=ns,
            top_k=subquery_topk,
            filter=filter,
            include_values=include_values,
            include_metadata=include_metadata,
            sparse_vector=sparse_vector,
            async_req=True,
            **kwargs,
        )
        for ns in target_namespaces
    ]

    for pending in async_results:
        aggregator.add_results(pending.get())

    return aggregator.get_results()
439 | 537 |
|
440 | 538 | @validate_and_convert_errors
|
441 | 539 | def update(
|
|
0 commit comments