Skip to content

Commit

Permalink
[refactor] Introduce PluginAware utility class (#443)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon authored Jan 31, 2025
1 parent 6e81b12 commit 980ac3b
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 135 deletions.
4 changes: 2 additions & 2 deletions pinecone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import logging

# Raise an exception if the user is attempting to use the SDK with deprecated plugins
# installed in their project.
# Raise an exception if the user is attempting to use the SDK with
# deprecated plugins installed in their project.
check_for_deprecated_plugins()

# Silence annoying log messages from the plugin interface
Expand Down
21 changes: 2 additions & 19 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pinecone.openapi_support.api_client import ApiClient


from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
from pinecone.utils import normalize_host, setup_openapi_client, PluginAware
from pinecone.core.openapi.db_control import API_VERSION
from pinecone.models import (
ServerlessSpec,
Expand All @@ -38,13 +38,11 @@
from .types import CreateIndexForModelEmbedTypedDict
from .request_factory import PineconeDBControlRequestFactory

from pinecone_plugin_interface import load_and_install as install_plugins

logger = logging.getLogger(__name__)
""" @private """


class Pinecone(PineconeDBControlInterface):
class Pinecone(PineconeDBControlInterface, PluginAware):
"""
A client for interacting with Pinecone's vector database.
Expand Down Expand Up @@ -113,21 +111,6 @@ def inference(self):
self._inference = _Inference(config=self.config, openapi_config=self.openapi_config)
return self._inference

def load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=self.pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins: {e}")

def create_index(
self,
name: str,
Expand Down
21 changes: 1 addition & 20 deletions pinecone/control/pinecone_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pinecone.core.openapi.db_control.api.manage_indexes_api import AsyncioManageIndexesApi
from pinecone.openapi_support import AsyncioApiClient

from pinecone.utils import normalize_host, setup_openapi_client, build_plugin_setup_client
from pinecone.utils import normalize_host, setup_openapi_client
from pinecone.core.openapi.db_control import API_VERSION
from pinecone.models import (
ServerlessSpec,
Expand All @@ -36,8 +36,6 @@
from .request_factory import PineconeDBControlRequestFactory
from .pinecone_interface_asyncio import PineconeAsyncioDBControlInterface

from pinecone_plugin_interface import load_and_install as install_plugins

logger = logging.getLogger(__name__)
""" @private """

Expand Down Expand Up @@ -104,8 +102,6 @@ def __init__(
self.index_host_store = IndexHostStore()
""" @private """

self.load_plugins()

async def __aenter__(self):
return self

Expand All @@ -122,21 +118,6 @@ def inference(self):
self._inference = _AsyncioInference(api_client=self.index_api.api_client)
return self._inference

def load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=self.pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins: {e}")

async def create_index(
self,
name: str,
Expand Down
21 changes: 2 additions & 19 deletions pinecone/data/features/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from pinecone.core.openapi.inference.apis import InferenceApi
from pinecone.core.openapi.inference.models import EmbeddingsList, RerankResult
from pinecone.core.openapi.inference import API_VERSION
from pinecone.utils import setup_openapi_client, build_plugin_setup_client
from pinecone.utils import setup_openapi_client, PluginAware

from pinecone_plugin_interface import load_and_install as install_plugins

from .inference_request_builder import (
InferenceRequestBuilder,
Expand All @@ -18,7 +17,7 @@
logger = logging.getLogger(__name__)


class Inference:
class Inference(PluginAware):
"""
The `Inference` class configures and uses the Pinecone Inference API to generate embeddings and
rank documents.
Expand All @@ -43,24 +42,8 @@ def __init__(self, config, openapi_config, **kwargs):
pool_threads=kwargs.get("pool_threads", 1),
api_version=API_VERSION,
)

self.load_plugins()

def load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=self.pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins: {e}")

def embed(
self,
model: Union[EmbedModelEnum, str],
Expand Down
30 changes: 7 additions & 23 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@
from ..utils import (
setup_openapi_client,
parse_non_empty_args,
build_plugin_setup_client,
validate_and_convert_errors,
PluginAware,
)
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from pinecone.openapi_support import OPENAPI_ENDPOINT_PARAMS

from multiprocessing.pool import ApplyResult
from concurrent.futures import as_completed

from pinecone_plugin_interface import load_and_install as install_plugins

logger = logging.getLogger(__name__)

Expand All @@ -52,7 +51,7 @@ def parse_query_response(response: QueryResponse):
return response


class Index(IndexInterface, ImportFeatureMixin):
class Index(IndexInterface, ImportFeatureMixin, PluginAware):
"""
A client for interacting with a Pinecone index via REST API.
For improved performance, use the Pinecone GRPC index client.
Expand All @@ -70,17 +69,17 @@ def __init__(
self.config = ConfigBuilder.build(
api_key=api_key, host=host, additional_headers=additional_headers, **kwargs
)
self._openapi_config = ConfigBuilder.build_openapi_config(self.config, openapi_config)
self._pool_threads = pool_threads
self.openapi_config = ConfigBuilder.build_openapi_config(self.config, openapi_config)
self.pool_threads = pool_threads

if kwargs.get("connection_pool_maxsize", None):
self._openapi_config.connection_pool_maxsize = kwargs.get("connection_pool_maxsize")
self.openapi_config.connection_pool_maxsize = kwargs.get("connection_pool_maxsize")

self._vector_api = setup_openapi_client(
api_client_klass=ApiClient,
api_klass=VectorOperationsApi,
config=self.config,
openapi_config=self._openapi_config,
openapi_config=self.openapi_config,
pool_threads=pool_threads,
api_version=API_VERSION,
)
Expand All @@ -90,22 +89,7 @@ def __init__(
# Pass the same api_client to the ImportFeatureMixin
super().__init__(api_client=self._api_client)

self._load_plugins()

def _load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self._openapi_config,
pool_threads=self._pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins in Index: {e}")
self.load_plugins()

def _openapi_kwargs(self, kwargs):
return {k: v for k, v in kwargs.items() if k in OPENAPI_ENDPOINT_PARAMS}
Expand Down
26 changes: 2 additions & 24 deletions pinecone/data/index_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@
SearchRecordsResponse,
)

from ..utils import (
setup_openapi_client,
parse_non_empty_args,
build_plugin_setup_client,
validate_and_convert_errors,
)
from ..utils import setup_openapi_client, parse_non_empty_args, validate_and_convert_errors
from .types import (
SparseVectorTypedDict,
VectorTypedDict,
Expand All @@ -47,7 +42,7 @@
from .vector_factory import VectorFactory
from .query_results_aggregator import QueryNamespacesResults
from .features.bulk_import import ImportFeatureMixinAsyncio
from pinecone_plugin_interface import load_and_install as install_plugins


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -107,23 +102,6 @@ def __init__(
# This is important for async context management to work correctly
super().__init__(api_client=self._api_client)

self._load_plugins()

def _load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self._openapi_config,
pool_threads=self._pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins in Index: {e}")

async def __aenter__(self):
return self

Expand Down
18 changes: 0 additions & 18 deletions pinecone/grpc/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
from typing import Optional

import logging
import grpc
from grpc._channel import Channel

Expand All @@ -12,10 +11,6 @@
from .grpc_runner import GrpcRunner
from concurrent.futures import ThreadPoolExecutor

from pinecone_plugin_interface import load_and_install as install_plugins

_logger = logging.getLogger(__name__)


class GRPCIndexBase(ABC):
"""
Expand Down Expand Up @@ -48,19 +43,6 @@ def __init__(
self._channel = channel or self._gen_channel()
self.stub = self.stub_class(self._channel)

self._load_plugins()

def _load_plugins(self):
"""@private"""
try:

def stub_openapi_client_builder(*args, **kwargs):
pass

install_plugins(self, stub_openapi_client_builder)
except Exception as e:
_logger.error(f"Error loading plugins in GRPCIndex: {e}")

@property
def threadpool_executor(self):
if self._pool is None:
Expand Down
2 changes: 2 additions & 0 deletions pinecone/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from .docslinks import docslinks
from .repr_overrides import install_json_repr_override
from .error_handling import validate_and_convert_errors
from .plugin_aware import PluginAware

__all__ = [
"PluginAware",
"check_kwargs",
"__version__",
"get_user_agent",
Expand Down
22 changes: 22 additions & 0 deletions pinecone/utils/plugin_aware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from .setup_openapi_client import build_plugin_setup_client
from pinecone_plugin_interface import load_and_install as install_plugins
import logging

logger = logging.getLogger(__name__)


class PluginAware:
def load_plugins(self):
"""@private"""
try:
# I don't expect this to ever throw, but wrapping this in a
# try block just in case to make sure a bad plugin doesn't
# halt client initialization.
openapi_client_builder = build_plugin_setup_client(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=self.pool_threads,
)
install_plugins(self, openapi_client_builder)
except Exception as e:
logger.error(f"Error loading plugins: {e}")
13 changes: 3 additions & 10 deletions tests/unit/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
ServerlessSpec as ServerlessSpecOpenApi,
IndexModelStatus,
)
from pinecone.utils import PluginAware

from pinecone.core.openapi.db_control.api.manage_indexes_api import ManageIndexesApi

import time
Expand Down Expand Up @@ -78,19 +80,10 @@ def index_list_response():

class TestControl:
def test_plugins_are_installed(self):
with patch("pinecone.control.pinecone.install_plugins") as mock_install_plugins:
with patch.object(PluginAware, "load_plugins") as mock_install_plugins:
Pinecone(api_key="asdf")
mock_install_plugins.assert_called_once()

def test_bad_plugin_doesnt_break_sdk(self):
with patch(
"pinecone.control.pinecone.install_plugins", side_effect=Exception("bad plugin")
):
try:
Pinecone(api_key="asdf")
except Exception as e:
assert False, f"Unexpected exception: {e}"

def test_default_host(self):
p = Pinecone(api_key="123-456-789")
assert p.index_api.api_client.configuration.host == "https://api.pinecone.io"
Expand Down

0 comments on commit 980ac3b

Please sign in to comment.