Skip to content

Commit c618fff

Browse files
committed
Fix sparse query types
1 parent df0e8e4 commit c618fff

File tree

2 files changed

+7
-28
lines changed

2 files changed

+7
-28
lines changed

pinecone/data/request_factory.py

+4-26
Original file line numberDiff line numberDiff line change
@@ -25,28 +25,6 @@
2525
logger = logging.getLogger(__name__)
2626

2727

28-
def parse_sparse_values_arg(
29-
sparse_values: Optional[Union[SparseValues, SparseVectorTypedDict]],
30-
) -> Optional[SparseValues]:
31-
if sparse_values is None:
32-
return None
33-
34-
if isinstance(sparse_values, SparseValues):
35-
return sparse_values
36-
37-
if (
38-
not isinstance(sparse_values, dict)
39-
or "indices" not in sparse_values
40-
or "values" not in sparse_values
41-
):
42-
raise ValueError(
43-
"Invalid sparse values argument. Expected a dict of: {'indices': List[int], 'values': List[float]}."
44-
f"Received: {sparse_values}"
45-
)
46-
47-
return SparseValues(indices=sparse_values["indices"], values=sparse_values["values"])
48-
49-
5028
def non_openapi_kwargs(kwargs):
5129
return {k: v for k, v in kwargs.items() if k not in OPENAPI_ENDPOINT_PARAMS}
5230

@@ -67,7 +45,7 @@ def query_request(
6745
if vector is not None and id is not None:
6846
raise ValueError("Cannot specify both `id` and `vector`")
6947

70-
sparse_vector = SparseValuesFactory.build(sparse_vector)
48+
sparse_vector_normalized = SparseValuesFactory.build(sparse_vector)
7149
args_dict = parse_non_empty_args(
7250
[
7351
("vector", vector),
@@ -78,7 +56,7 @@ def query_request(
7856
("filter", filter),
7957
("include_values", include_values),
8058
("include_metadata", include_metadata),
81-
("sparse_vector", sparse_vector),
59+
("sparse_vector", sparse_vector_normalized),
8260
]
8361
)
8462

@@ -131,13 +109,13 @@ def update_request(
131109
**kwargs,
132110
) -> UpdateRequest:
133111
_check_type = kwargs.pop("_check_type", False)
134-
sparse_values = parse_sparse_values_arg(sparse_values)
112+
sparse_values_normalized = SparseValuesFactory.build(sparse_values)
135113
args_dict = parse_non_empty_args(
136114
[
137115
("values", values),
138116
("set_metadata", set_metadata),
139117
("namespace", namespace),
140-
("sparse_values", sparse_values),
118+
("sparse_values", sparse_values_normalized),
141119
]
142120
)
143121

pinecone/data/sparse_values_factory.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Mapping
2-
from typing import Union, Dict, Optional
2+
from typing import Union, Optional
33

44
from ..utils import convert_to_list
55

@@ -10,6 +10,7 @@
1010
)
1111

1212
from .dataclasses import SparseValues
13+
from .types import SparseVectorTypedDict
1314
from pinecone.core.openapi.db_data.models import SparseValues as OpenApiSparseValues
1415

1516

@@ -18,7 +19,7 @@ class SparseValuesFactory:
1819

1920
@staticmethod
2021
def build(
21-
input: Union[Dict, Optional[SparseValues], OpenApiSparseValues],
22+
input: Optional[Union[SparseValues, OpenApiSparseValues, SparseVectorTypedDict]],
2223
) -> Optional[OpenApiSparseValues]:
2324
if input is None:
2425
return input

0 commit comments

Comments
 (0)