Skip to content

Commit 980ac3b

Browse files
authored
[refactor] Introduce PluginAware utility class (#443)
1 parent 6e81b12 commit 980ac3b

File tree

10 files changed

+43
-135
lines changed

10 files changed

+43
-135
lines changed

pinecone/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
import logging
2222

23-
# Raise an exception if the user is attempting to use the SDK with deprecated plugins
24-
# installed in their project.
23+
# Raise an exception if the user is attempting to use the SDK with
24+
# deprecated plugins installed in their project.
2525
check_for_deprecated_plugins()
2626

2727
# Silence annoying log messages from the plugin interface

pinecone/control/pinecone.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pinecone.openapi_support.api_client import ApiClient
1212

1313

14-
from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
14+
from pinecone.utils import normalize_host, setup_openapi_client, PluginAware
1515
from pinecone.core.openapi.db_control import API_VERSION
1616
from pinecone.models import (
1717
ServerlessSpec,
@@ -38,13 +38,11 @@
3838
from .types import CreateIndexForModelEmbedTypedDict
3939
from .request_factory import PineconeDBControlRequestFactory
4040

41-
from pinecone_plugin_interface import load_and_install as install_plugins
42-
4341
logger = logging.getLogger(__name__)
4442
""" @private """
4543

4644

47-
class Pinecone(PineconeDBControlInterface):
45+
class Pinecone(PineconeDBControlInterface, PluginAware):
4846
"""
4947
A client for interacting with Pinecone's vector database.
5048
@@ -113,21 +111,6 @@ def inference(self):
113111
self._inference = _Inference(config=self.config, openapi_config=self.openapi_config)
114112
return self._inference
115113

116-
def load_plugins(self):
117-
"""@private"""
118-
try:
119-
# I don't expect this to ever throw, but wrapping this in a
120-
# try block just in case to make sure a bad plugin doesn't
121-
# halt client initialization.
122-
openapi_client_builder = build_plugin_setup_client(
123-
config=self.config,
124-
openapi_config=self.openapi_config,
125-
pool_threads=self.pool_threads,
126-
)
127-
install_plugins(self, openapi_client_builder)
128-
except Exception as e:
129-
logger.error(f"Error loading plugins: {e}")
130-
131114
def create_index(
132115
self,
133116
name: str,

pinecone/control/pinecone_asyncio.py

+1-20
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pinecone.core.openapi.db_control.api.manage_indexes_api import AsyncioManageIndexesApi
1010
from pinecone.openapi_support import AsyncioApiClient
1111

12-
from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
12+
from pinecone.utils import normalize_host, setup_openapi_client
1313
from pinecone.core.openapi.db_control import API_VERSION
1414
from pinecone.models import (
1515
ServerlessSpec,
@@ -36,8 +36,6 @@
3636
from .request_factory import PineconeDBControlRequestFactory
3737
from .pinecone_interface_asyncio import PineconeAsyncioDBControlInterface
3838

39-
from pinecone_plugin_interface import load_and_install as install_plugins
40-
4139
logger = logging.getLogger(__name__)
4240
""" @private """
4341

@@ -104,8 +102,6 @@ def __init__(
104102
self.index_host_store = IndexHostStore()
105103
""" @private """
106104

107-
self.load_plugins()
108-
109105
async def __aenter__(self):
110106
return self
111107

@@ -122,21 +118,6 @@ def inference(self):
122118
self._inference = _AsyncioInference(api_client=self.index_api.api_client)
123119
return self._inference
124120

125-
def load_plugins(self):
126-
"""@private"""
127-
try:
128-
# I don't expect this to ever throw, but wrapping this in a
129-
# try block just in case to make sure a bad plugin doesn't
130-
# halt client initialization.
131-
openapi_client_builder = build_plugin_setup_client(
132-
config=self.config,
133-
openapi_config=self.openapi_config,
134-
pool_threads=self.pool_threads,
135-
)
136-
install_plugins(self, openapi_client_builder)
137-
except Exception as e:
138-
logger.error(f"Error loading plugins: {e}")
139-
140121
async def create_index(
141122
self,
142123
name: str,

pinecone/data/features/inference/inference.py

+2-19
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from pinecone.core.openapi.inference.apis import InferenceApi
66
from pinecone.core.openapi.inference.models import EmbeddingsList, RerankResult
77
from pinecone.core.openapi.inference import API_VERSION
8-
from pinecone.utils import setup_openapi_client, build_plugin_setup_client
8+
from pinecone.utils import setup_openapi_client, PluginAware
99

10-
from pinecone_plugin_interface import load_and_install as install_plugins
1110

1211
from .inference_request_builder import (
1312
InferenceRequestBuilder,
@@ -18,7 +17,7 @@
1817
logger = logging.getLogger(__name__)
1918

2019

21-
class Inference:
20+
class Inference(PluginAware):
2221
"""
2322
The `Inference` class configures and uses the Pinecone Inference API to generate embeddings and
2423
rank documents.
@@ -43,24 +42,8 @@ def __init__(self, config, openapi_config, **kwargs):
4342
pool_threads=kwargs.get("pool_threads", 1),
4443
api_version=API_VERSION,
4544
)
46-
4745
self.load_plugins()
4846

49-
def load_plugins(self):
50-
"""@private"""
51-
try:
52-
# I don't expect this to ever throw, but wrapping this in a
53-
# try block just in case to make sure a bad plugin doesn't
54-
# halt client initialization.
55-
openapi_client_builder = build_plugin_setup_client(
56-
config=self.config,
57-
openapi_config=self.openapi_config,
58-
pool_threads=self.pool_threads,
59-
)
60-
install_plugins(self, openapi_client_builder)
61-
except Exception as e:
62-
logger.error(f"Error loading plugins: {e}")
63-
6447
def embed(
6548
self,
6649
model: Union[EmbedModelEnum, str],

pinecone/data/index.py

+7-23
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,15 @@
3333
from ..utils import (
3434
setup_openapi_client,
3535
parse_non_empty_args,
36-
build_plugin_setup_client,
3736
validate_and_convert_errors,
37+
PluginAware,
3838
)
3939
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
4040
from pinecone.openapi_support import OPENAPI_ENDPOINT_PARAMS
4141

4242
from multiprocessing.pool import ApplyResult
4343
from concurrent.futures import as_completed
4444

45-
from pinecone_plugin_interface import load_and_install as install_plugins
4645

4746
logger = logging.getLogger(__name__)
4847

@@ -52,7 +51,7 @@ def parse_query_response(response: QueryResponse):
5251
return response
5352

5453

55-
class Index(IndexInterface, ImportFeatureMixin):
54+
class Index(IndexInterface, ImportFeatureMixin, PluginAware):
5655
"""
5756
A client for interacting with a Pinecone index via REST API.
5857
For improved performance, use the Pinecone GRPC index client.
@@ -70,17 +69,17 @@ def __init__(
7069
self.config = ConfigBuilder.build(
7170
api_key=api_key, host=host, additional_headers=additional_headers, **kwargs
7271
)
73-
self._openapi_config = ConfigBuilder.build_openapi_config(self.config, openapi_config)
74-
self._pool_threads = pool_threads
72+
self.openapi_config = ConfigBuilder.build_openapi_config(self.config, openapi_config)
73+
self.pool_threads = pool_threads
7574

7675
if kwargs.get("connection_pool_maxsize", None):
77-
self._openapi_config.connection_pool_maxsize = kwargs.get("connection_pool_maxsize")
76+
self.openapi_config.connection_pool_maxsize = kwargs.get("connection_pool_maxsize")
7877

7978
self._vector_api = setup_openapi_client(
8079
api_client_klass=ApiClient,
8180
api_klass=VectorOperationsApi,
8281
config=self.config,
83-
openapi_config=self._openapi_config,
82+
openapi_config=self.openapi_config,
8483
pool_threads=pool_threads,
8584
api_version=API_VERSION,
8685
)
@@ -90,22 +89,7 @@ def __init__(
9089
# Pass the same api_client to the ImportFeatureMixin
9190
super().__init__(api_client=self._api_client)
9291

93-
self._load_plugins()
94-
95-
def _load_plugins(self):
96-
"""@private"""
97-
try:
98-
# I don't expect this to ever throw, but wrapping this in a
99-
# try block just in case to make sure a bad plugin doesn't
100-
# halt client initialization.
101-
openapi_client_builder = build_plugin_setup_client(
102-
config=self.config,
103-
openapi_config=self._openapi_config,
104-
pool_threads=self._pool_threads,
105-
)
106-
install_plugins(self, openapi_client_builder)
107-
except Exception as e:
108-
logger.error(f"Error loading plugins in Index: {e}")
92+
self.load_plugins()
10993

11094
def _openapi_kwargs(self, kwargs):
11195
return {k: v for k, v in kwargs.items() if k in OPENAPI_ENDPOINT_PARAMS}

pinecone/data/index_asyncio.py

+2-24
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,7 @@
2323
SearchRecordsResponse,
2424
)
2525

26-
from ..utils import (
27-
setup_openapi_client,
28-
parse_non_empty_args,
29-
build_plugin_setup_client,
30-
validate_and_convert_errors,
31-
)
26+
from ..utils import setup_openapi_client, parse_non_empty_args, validate_and_convert_errors
3227
from .types import (
3328
SparseVectorTypedDict,
3429
VectorTypedDict,
@@ -47,7 +42,7 @@
4742
from .vector_factory import VectorFactory
4843
from .query_results_aggregator import QueryNamespacesResults
4944
from .features.bulk_import import ImportFeatureMixinAsyncio
50-
from pinecone_plugin_interface import load_and_install as install_plugins
45+
5146

5247
logger = logging.getLogger(__name__)
5348

@@ -107,23 +102,6 @@ def __init__(
107102
# This is important for async context management to work correctly
108103
super().__init__(api_client=self._api_client)
109104

110-
self._load_plugins()
111-
112-
def _load_plugins(self):
113-
"""@private"""
114-
try:
115-
# I don't expect this to ever throw, but wrapping this in a
116-
# try block just in case to make sure a bad plugin doesn't
117-
# halt client initialization.
118-
openapi_client_builder = build_plugin_setup_client(
119-
config=self.config,
120-
openapi_config=self._openapi_config,
121-
pool_threads=self._pool_threads,
122-
)
123-
install_plugins(self, openapi_client_builder)
124-
except Exception as e:
125-
logger.error(f"Error loading plugins in Index: {e}")
126-
127105
async def __aenter__(self):
128106
return self
129107

pinecone/grpc/base.py

-18
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from abc import ABC, abstractmethod
22
from typing import Optional
33

4-
import logging
54
import grpc
65
from grpc._channel import Channel
76

@@ -12,10 +11,6 @@
1211
from .grpc_runner import GrpcRunner
1312
from concurrent.futures import ThreadPoolExecutor
1413

15-
from pinecone_plugin_interface import load_and_install as install_plugins
16-
17-
_logger = logging.getLogger(__name__)
18-
1914

2015
class GRPCIndexBase(ABC):
2116
"""
@@ -48,19 +43,6 @@ def __init__(
4843
self._channel = channel or self._gen_channel()
4944
self.stub = self.stub_class(self._channel)
5045

51-
self._load_plugins()
52-
53-
def _load_plugins(self):
54-
"""@private"""
55-
try:
56-
57-
def stub_openapi_client_builder(*args, **kwargs):
58-
pass
59-
60-
install_plugins(self, stub_openapi_client_builder)
61-
except Exception as e:
62-
_logger.error(f"Error loading plugins in GRPCIndex: {e}")
63-
6446
@property
6547
def threadpool_executor(self):
6648
if self._pool is None:

pinecone/utils/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
from .docslinks import docslinks
1212
from .repr_overrides import install_json_repr_override
1313
from .error_handling import validate_and_convert_errors
14+
from .plugin_aware import PluginAware
1415

1516
__all__ = [
17+
"PluginAware",
1618
"check_kwargs",
1719
"__version__",
1820
"get_user_agent",

pinecone/utils/plugin_aware.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from .setup_openapi_client import build_plugin_setup_client
2+
from pinecone_plugin_interface import load_and_install as install_plugins
3+
import logging
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
class PluginAware:
9+
def load_plugins(self):
10+
"""@private"""
11+
try:
12+
# I don't expect this to ever throw, but wrapping this in a
13+
# try block just in case to make sure a bad plugin doesn't
14+
# halt client initialization.
15+
openapi_client_builder = build_plugin_setup_client(
16+
config=self.config,
17+
openapi_config=self.openapi_config,
18+
pool_threads=self.pool_threads,
19+
)
20+
install_plugins(self, openapi_client_builder)
21+
except Exception as e:
22+
logger.error(f"Error loading plugins: {e}")

tests/unit/test_control.py

+3-10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
ServerlessSpec as ServerlessSpecOpenApi,
2121
IndexModelStatus,
2222
)
23+
from pinecone.utils import PluginAware
24+
2325
from pinecone.core.openapi.db_control.api.manage_indexes_api import ManageIndexesApi
2426

2527
import time
@@ -78,19 +80,10 @@ def index_list_response():
7880

7981
class TestControl:
8082
def test_plugins_are_installed(self):
81-
with patch("pinecone.control.pinecone.install_plugins") as mock_install_plugins:
83+
with patch.object(PluginAware, "load_plugins") as mock_install_plugins:
8284
Pinecone(api_key="asdf")
8385
mock_install_plugins.assert_called_once()
8486

85-
def test_bad_plugin_doesnt_break_sdk(self):
86-
with patch(
87-
"pinecone.control.pinecone.install_plugins", side_effect=Exception("bad plugin")
88-
):
89-
try:
90-
Pinecone(api_key="asdf")
91-
except Exception as e:
92-
assert False, f"Unexpected exception: {e}"
93-
9487
def test_default_host(self):
9588
p = Pinecone(api_key="123-456-789")
9689
assert p.index_api.api_client.configuration.host == "https://api.pinecone.io"

0 commit comments

Comments
 (0)