From f2b1de0778830dfdf19d869146f50be5cad7919d Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Wed, 29 Jan 2025 00:49:49 -0500 Subject: [PATCH] Organize enum types into enums subpackage --- pinecone/__init__.py | 1 + pinecone/control/pinecone.py | 38 ++++++++---- pinecone/enums/__init__.py | 18 ++++++ pinecone/enums/clouds.py | 22 +++++++ pinecone/enums/deletion_protection.py | 6 ++ pinecone/enums/metric.py | 11 ++++ .../pod_index_environment.py} | 25 ++------ pinecone/enums/pod_type.py | 20 ++++++ pinecone/enums/vector_type.py | 14 +++++ pinecone/models/__init__.py | 9 +-- pinecone/models/pod_spec.py | 62 ++++++++++++------- pinecone/models/serverless_spec.py | 2 +- ...est_configure_index_deletion_protection.py | 11 +++- .../test_create_index_sl_happy_path.py | 8 ++- tests/unit/test_control.py | 6 ++ 15 files changed, 182 insertions(+), 71 deletions(-) create mode 100644 pinecone/enums/__init__.py create mode 100644 pinecone/enums/clouds.py create mode 100644 pinecone/enums/deletion_protection.py create mode 100644 pinecone/enums/metric.py rename pinecone/{models/clouds.py => enums/pod_index_environment.py} (61%) create mode 100644 pinecone/enums/pod_type.py create mode 100644 pinecone/enums/vector_type.py diff --git a/pinecone/__init__.py b/pinecone/__init__.py index 01472989..2abbef3c 100644 --- a/pinecone/__init__.py +++ b/pinecone/__init__.py @@ -14,6 +14,7 @@ from .control import * from .data import * from .models import * +from .enums import * from .utils import __version__ diff --git a/pinecone/control/pinecone.py b/pinecone/control/pinecone.py index 861403ec..b42b9371 100644 --- a/pinecone/control/pinecone.py +++ b/pinecone/control/pinecone.py @@ -1,6 +1,6 @@ import time import logging -from typing import Optional, Dict, Any, Union, Literal +from typing import Optional, Dict, Any, Union from .index_host_store import IndexHostStore from .pinecone_interface import PineconeDBControlInterface @@ -18,7 +18,7 @@ ConfigureIndexRequest, ConfigureIndexRequestSpec, 
ConfigureIndexRequestSpecPod, - DeletionProtection, + DeletionProtection as DeletionProtectionModel, IndexSpec, IndexTags, ServerlessSpec as ServerlessSpecModel, @@ -31,6 +31,7 @@ from pinecone.utils import parse_non_empty_args, docslinks from pinecone.data import _Index, _AsyncioIndex, _Inference +from pinecone.enums import Metric, VectorType, DeletionProtection, PodType from pinecone_plugin_interface import load_and_install as install_plugins @@ -179,19 +180,26 @@ def create_index( name: str, spec: Union[Dict, ServerlessSpec, PodSpec], dimension: Optional[int] = None, - metric: Optional[Literal["cosine", "euclidean", "dotproduct"]] = "cosine", + metric: Optional[Union[Metric, str]] = Metric.COSINE, timeout: Optional[int] = None, - deletion_protection: Optional[Literal["enabled", "disabled"]] = "disabled", - vector_type: Optional[Literal["dense", "sparse"]] = "dense", + deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED, + vector_type: Optional[Union[VectorType, str]] = VectorType.DENSE, tags: Optional[Dict[str, str]] = None, ): - api_instance = self.index_api + # Convert Enums to their string values if necessary + metric = metric.value if isinstance(metric, Metric) else str(metric) + vector_type = vector_type.value if isinstance(vector_type, VectorType) else str(vector_type) + deletion_protection = ( + deletion_protection.value + if isinstance(deletion_protection, DeletionProtection) + else str(deletion_protection) + ) - if vector_type == "sparse" and dimension is not None: + if vector_type == VectorType.SPARSE.value and dimension is not None: raise ValueError("dimension should not be specified for sparse indexes") if deletion_protection in ["enabled", "disabled"]: - dp = DeletionProtection(deletion_protection) + dp = DeletionProtectionModel(deletion_protection) else: raise ValueError("deletion_protection must be either 'enabled' or 'disabled'") @@ -202,6 +210,7 @@ def create_index( index_spec = 
self._parse_index_spec(spec) + api_instance = self.index_api api_instance.create_index( create_index_request=CreateIndexRequest( **parse_non_empty_args( @@ -301,17 +310,19 @@ def configure_index( self, name: str, replicas: Optional[int] = None, - pod_type: Optional[str] = None, - deletion_protection: Optional[Literal["enabled", "disabled"]] = None, + pod_type: Optional[Union[PodType, str]] = None, + deletion_protection: Optional[Union[DeletionProtection, str]] = None, tags: Optional[Dict[str, str]] = None, ): api_instance = self.index_api description = self.describe_index(name=name) if deletion_protection is None: - dp = DeletionProtection(description.deletion_protection) + dp = DeletionProtectionModel(description.deletion_protection) + elif isinstance(deletion_protection, DeletionProtection): + dp = DeletionProtectionModel(deletion_protection.value) elif deletion_protection in ["enabled", "disabled"]: - dp = DeletionProtection(deletion_protection) + dp = DeletionProtectionModel(deletion_protection) else: raise ValueError("deletion_protection must be either 'enabled' or 'disabled'") @@ -330,7 +341,8 @@ def configure_index( pod_config_args: Dict[str, Any] = {} if pod_type: - pod_config_args.update(pod_type=pod_type) + new_pod_type = pod_type.value if isinstance(pod_type, PodType) else pod_type + pod_config_args.update(pod_type=new_pod_type) if replicas: pod_config_args.update(replicas=replicas) diff --git a/pinecone/enums/__init__.py b/pinecone/enums/__init__.py new file mode 100644 index 00000000..38d11eb3 --- /dev/null +++ b/pinecone/enums/__init__.py @@ -0,0 +1,18 @@ +from .clouds import CloudProvider, AwsRegion, GcpRegion, AzureRegion +from .deletion_protection import DeletionProtection +from .metric import Metric +from .pod_index_environment import PodIndexEnvironment +from .pod_type import PodType +from .vector_type import VectorType + +__all__ = [ + "CloudProvider", + "AwsRegion", + "GcpRegion", + "AzureRegion", + "DeletionProtection", + "Metric", + 
"PodIndexEnvironment", + "PodType", + "VectorType", +] diff --git a/pinecone/enums/clouds.py b/pinecone/enums/clouds.py new file mode 100644 index 00000000..6ee9f0fa --- /dev/null +++ b/pinecone/enums/clouds.py @@ -0,0 +1,22 @@ +from enum import Enum + + +class CloudProvider(Enum): + AWS = "aws" + GCP = "gcp" + AZURE = "azure" + + +class AwsRegion(Enum): + US_EAST_1 = "us-east-1" + US_WEST_2 = "us-west-2" + EU_WEST_1 = "eu-west-1" + + +class GcpRegion(Enum): + US_CENTRAL1 = "us-central1" + EUROPE_WEST4 = "europe-west4" + + +class AzureRegion(Enum): + EAST_US = "eastus2" diff --git a/pinecone/enums/deletion_protection.py b/pinecone/enums/deletion_protection.py new file mode 100644 index 00000000..dde4f261 --- /dev/null +++ b/pinecone/enums/deletion_protection.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class DeletionProtection(Enum): + ENABLED = "enabled" + DISABLED = "disabled" diff --git a/pinecone/enums/metric.py b/pinecone/enums/metric.py new file mode 100644 index 00000000..c3488220 --- /dev/null +++ b/pinecone/enums/metric.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class Metric(Enum): + """ + The metric specifies how Pinecone should calculate the distance between vectors when querying an index. 
+ """ + + COSINE = "cosine" + EUCLIDEAN = "euclidean" + DOTPRODUCT = "dotproduct" diff --git a/pinecone/models/clouds.py b/pinecone/enums/pod_index_environment.py similarity index 61% rename from pinecone/models/clouds.py rename to pinecone/enums/pod_index_environment.py index 81c467b9..13b9fa87 100644 --- a/pinecone/models/clouds.py +++ b/pinecone/enums/pod_index_environment.py @@ -1,28 +1,11 @@ from enum import Enum -class CloudProvider(Enum): - AWS = "aws" - GCP = "gcp" - AZURE = "azure" - - -class AwsRegion(Enum): - US_EAST_1 = "us-east-1" - US_WEST_2 = "us-west-2" - EU_WEST_1 = "eu-west-1" - - -class GcpRegion(Enum): - US_CENTRAL1 = "us-central1" - EUROPE_WEST4 = "europe-west4" - - -class AzureRegion(Enum): - EAST_US = "eastus2" - - class PodIndexEnvironment(Enum): + """ + These environment strings are used to specify where a pod index should be deployed. + """ + US_WEST1_GCP = "us-west1-gcp" US_CENTRAL1_GCP = "us-central1-gcp" US_WEST4_GCP = "us-west4-gcp" diff --git a/pinecone/enums/pod_type.py b/pinecone/enums/pod_type.py new file mode 100644 index 00000000..2e947c85 --- /dev/null +++ b/pinecone/enums/pod_type.py @@ -0,0 +1,20 @@ +from enum import Enum + + +class PodType(Enum): + """ + PodType represents the available pod types for a pod index. + """ + + P1_X1 = "p1.x1" + P1_X2 = "p1.x2" + P1_X4 = "p1.x4" + P1_X8 = "p1.x8" + S1_X1 = "s1.x1" + S1_X2 = "s1.x2" + S1_X4 = "s1.x4" + S1_X8 = "s1.x8" + P2_X1 = "p2.x1" + P2_X2 = "p2.x2" + P2_X4 = "p2.x4" + P2_X8 = "p2.x8" diff --git a/pinecone/enums/vector_type.py b/pinecone/enums/vector_type.py new file mode 100644 index 00000000..cc3d6f31 --- /dev/null +++ b/pinecone/enums/vector_type.py @@ -0,0 +1,14 @@ +from enum import Enum + + +class VectorType(Enum): + """ + VectorType is used to specify the type of vector you will store in the index. + + Dense vectors are used to store dense embeddings, which are vectors with non-zero values in most of the dimensions. 
+ + Sparse vectors are used to store sparse embeddings, which allow vectors with zero values in most of the dimensions to be represented concisely. + """ + + DENSE = "dense" + SPARSE = "sparse" diff --git a/pinecone/models/__init__.py b/pinecone/models/__init__.py index 12d7cdea..e9b187ea 100644 --- a/pinecone/models/__init__.py +++ b/pinecone/models/__init__.py @@ -1,11 +1,11 @@ from .index_description import ServerlessSpecDefinition, PodSpecDefinition from .collection_description import CollectionDescription from .serverless_spec import ServerlessSpec -from .pod_spec import PodSpec +from .pod_spec import PodSpec, PodType from .index_list import IndexList from .collection_list import CollectionList from .index_model import IndexModel -from .clouds import CloudProvider, AwsRegion, GcpRegion, AzureRegion, PodIndexEnvironment +from ..enums.metric import Metric __all__ = [ "CollectionDescription", @@ -16,9 +16,4 @@ "IndexList", "CollectionList", "IndexModel", - "CloudProvider", - "AwsRegion", - "GcpRegion", - "AzureRegion", - "PodIndexEnvironment", ] diff --git a/pinecone/models/pod_spec.py b/pinecone/models/pod_spec.py index 2f37fd81..2e6a41b9 100644 --- a/pinecone/models/pod_spec.py +++ b/pinecone/models/pod_spec.py @@ -1,7 +1,11 @@ -from typing import NamedTuple, Optional, Dict +from dataclasses import dataclass, field +from typing import Optional, Dict, Union +from ..enums import PodIndexEnvironment, PodType -class PodSpec(NamedTuple): + +@dataclass(frozen=True) +class PodSpec: """ PodSpec represents the configuration used to deploy a pod-based index. @@ -33,32 +37,16 @@ class PodSpec(NamedTuple): This value combines pod type and pod size into a single string. This configuration is your main lever for vertical scaling. """ - metadata_config: Optional[Dict] = {} + metadata_config: Optional[Dict] = field(default_factory=dict) """ - If you are storing a lot of metadata, you can use this configuration to limit the fields which are indexed for search. 
+ If you are storing a lot of metadata, you can use this configuration to limit the fields which are indexed for search. This configuration should be a dictionary with the key 'indexed' and the value as a list of fields to index. - For example, if your vectors have metadata along like this: - - ```python - from pinecone import Vector - - vector = Vector( - id='237438191', - values=[...], - metadata={ - 'productId': '237438191', - 'description': 'Stainless Steel Tumbler with Straw', - 'category': 'kitchen', - 'price': '19.99' - } - ) - ``` - - You might want to limit which fields are indexed with metadata config such as this: + Example: ``` {'indexed': ['field1', 'field2']} + ``` """ source_collection: Optional[str] = None @@ -66,8 +54,34 @@ class PodSpec(NamedTuple): The name of the collection to use as the source for the pod index. This configuration is only used when creating a pod index from an existing collection. """ - def asdict(self): + def __init__( + self, + environment: Union[PodIndexEnvironment, str], + pod_type: Union[PodType, str] = "p1.x1", + replicas: Optional[int] = None, + shards: Optional[int] = None, + pods: Optional[int] = None, + metadata_config: Optional[Dict] = None, + source_collection: Optional[str] = None, + ): + object.__setattr__( + self, + "environment", + environment.value if isinstance(environment, PodIndexEnvironment) else str(environment), + ) + object.__setattr__( + self, "pod_type", pod_type.value if isinstance(pod_type, PodType) else str(pod_type) + ) + object.__setattr__(self, "replicas", replicas) + object.__setattr__(self, "shards", shards) + object.__setattr__(self, "pods", pods) + object.__setattr__( + self, "metadata_config", metadata_config if metadata_config is not None else {} + ) + object.__setattr__(self, "source_collection", source_collection) + + def asdict(self) -> Dict: """ Returns the PodSpec as a dictionary. 
""" - return {"pod": self._asdict()} + return {"pod": self.__dict__} diff --git a/pinecone/models/serverless_spec.py b/pinecone/models/serverless_spec.py index e4582fe6..1fc51564 100644 --- a/pinecone/models/serverless_spec.py +++ b/pinecone/models/serverless_spec.py @@ -2,7 +2,7 @@ from typing import Union from enum import Enum -from .clouds import CloudProvider, AwsRegion, GcpRegion, AzureRegion +from ..enums import CloudProvider, AwsRegion, GcpRegion, AzureRegion @dataclass(frozen=True) diff --git a/tests/integration/control/serverless/test_configure_index_deletion_protection.py b/tests/integration/control/serverless/test_configure_index_deletion_protection.py index 9d6c3c4e..fb12897a 100644 --- a/tests/integration/control/serverless/test_configure_index_deletion_protection.py +++ b/tests/integration/control/serverless/test_configure_index_deletion_protection.py @@ -1,10 +1,15 @@ import pytest +from pinecone import DeletionProtection class TestDeletionProtection: - def test_deletion_protection(self, client, create_sl_index_params): + @pytest.mark.parametrize( + "dp_enabled,dp_disabled", + [("enabled", "disabled"), (DeletionProtection.ENABLED, DeletionProtection.DISABLED)], + ) + def test_deletion_protection(self, client, create_sl_index_params, dp_enabled, dp_disabled): name = create_sl_index_params["name"] - client.create_index(**create_sl_index_params, deletion_protection="enabled") + client.create_index(**create_sl_index_params, deletion_protection=dp_enabled) desc = client.describe_index(name) assert desc.deletion_protection == "enabled" @@ -12,7 +17,7 @@ def test_deletion_protection(self, client, create_sl_index_params): client.delete_index(name) assert "Deletion protection is enabled for this index" in str(e.value) - client.configure_index(name, deletion_protection="disabled") + client.configure_index(name, deletion_protection=dp_disabled) desc = client.describe_index(name) assert desc.deletion_protection == "disabled" diff --git 
a/tests/integration/control/serverless/test_create_index_sl_happy_path.py b/tests/integration/control/serverless/test_create_index_sl_happy_path.py index eb587587..69aeb0e8 100644 --- a/tests/integration/control/serverless/test_create_index_sl_happy_path.py +++ b/tests/integration/control/serverless/test_create_index_sl_happy_path.py @@ -1,4 +1,5 @@ import pytest +from pinecone import Metric, VectorType class TestCreateSLIndexHappyPath: @@ -13,7 +14,10 @@ def test_create_index(self, client, create_sl_index_params): assert desc.deletion_protection == "disabled" # default value assert desc.vector_type == "dense" # default value - @pytest.mark.parametrize("metric", ["cosine", "euclidean", "dotproduct"]) + @pytest.mark.parametrize( + "metric", + ["cosine", "euclidean", "dotproduct", Metric.COSINE, Metric.EUCLIDEAN, Metric.DOTPRODUCT], + ) def test_create_default_index_with_metric(self, client, create_sl_index_params, metric): create_sl_index_params["metric"] = metric client.create_index(**create_sl_index_params) @@ -24,7 +28,7 @@ def test_create_default_index_with_metric(self, client, create_sl_index_params, @pytest.mark.parametrize("metric", ["cosine", "euclidean", "dotproduct"]) def test_create_dense_index_with_metric(self, client, create_sl_index_params, metric): create_sl_index_params["metric"] = metric - create_sl_index_params["vector_type"] = "dense" + create_sl_index_params["vector_type"] = VectorType.DENSE client.create_index(**create_sl_index_params) desc = client.describe_index(create_sl_index_params["name"]) assert desc.metric == metric diff --git a/tests/unit/test_control.py b/tests/unit/test_control.py index f11a462b..1115948d 100644 --- a/tests/unit/test_control.py +++ b/tests/unit/test_control.py @@ -9,6 +9,8 @@ CloudProvider, AwsRegion, GcpRegion, + PodIndexEnvironment, + PodType, ) from pinecone.core.openapi.db_control.models import IndexList, IndexModel, DeletionProtection from pinecone.core.openapi.db_control.api.manage_indexes_api import 
ManageIndexesApi @@ -167,6 +169,10 @@ def test_create_index_with_timeout( {"serverless": {"cloud": "aws", "region": "us-west1"}}, {"serverless": {"cloud": "aws", "region": "us-west1", "uknown_key": "value"}}, PodSpec(environment="us-west1-gcp", pod_type="p1.x1"), + PodSpec(environment=PodIndexEnvironment.US_WEST1_GCP, pod_type=PodType.P2_X2), + PodSpec(environment=PodIndexEnvironment.US_WEST1_GCP, pod_type="s1.x4"), + PodSpec(environment=PodIndexEnvironment.US_EAST1_AWS, pod_type="unknown-pod-type"), + PodSpec(environment="us-west1-gcp", pod_type="p1.x1", pods=2, replicas=1, shards=1), {"pod": {"environment": "us-west1-gcp", "pod_type": "p1.x1"}}, {"pod": {"environment": "us-west1-gcp", "pod_type": "p1.x1", "unknown_key": "value"}}, {