Skip to content

Commit

Permalink
Organize enum types into enums subpackage
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Jan 29, 2025
1 parent 66ae950 commit f2b1de0
Show file tree
Hide file tree
Showing 15 changed files with 182 additions and 71 deletions.
1 change: 1 addition & 0 deletions pinecone/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .control import *
from .data import *
from .models import *
from .enums import *

from .utils import __version__

Expand Down
38 changes: 25 additions & 13 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
import logging
from typing import Optional, Dict, Any, Union, Literal
from typing import Optional, Dict, Any, Union

from .index_host_store import IndexHostStore
from .pinecone_interface import PineconeDBControlInterface
Expand All @@ -18,7 +18,7 @@
ConfigureIndexRequest,
ConfigureIndexRequestSpec,
ConfigureIndexRequestSpecPod,
DeletionProtection,
DeletionProtection as DeletionProtectionModel,
IndexSpec,
IndexTags,
ServerlessSpec as ServerlessSpecModel,
Expand All @@ -31,6 +31,7 @@
from pinecone.utils import parse_non_empty_args, docslinks

from pinecone.data import _Index, _AsyncioIndex, _Inference
from pinecone.enums import Metric, VectorType, DeletionProtection, PodType

from pinecone_plugin_interface import load_and_install as install_plugins

Expand Down Expand Up @@ -179,19 +180,26 @@ def create_index(
name: str,
spec: Union[Dict, ServerlessSpec, PodSpec],
dimension: Optional[int] = None,
metric: Optional[Literal["cosine", "euclidean", "dotproduct"]] = "cosine",
metric: Optional[Union[Metric, str]] = Metric.COSINE,
timeout: Optional[int] = None,
deletion_protection: Optional[Literal["enabled", "disabled"]] = "disabled",
vector_type: Optional[Literal["dense", "sparse"]] = "dense",
deletion_protection: Optional[Union[DeletionProtection, str]] = DeletionProtection.DISABLED,
vector_type: Optional[Union[VectorType, str]] = VectorType.DENSE,
tags: Optional[Dict[str, str]] = None,
):
api_instance = self.index_api
# Convert Enums to their string values if necessary
metric = metric.value if isinstance(metric, Metric) else str(metric)
vector_type = vector_type.value if isinstance(vector_type, VectorType) else str(vector_type)
deletion_protection = (
deletion_protection.value
if isinstance(deletion_protection, DeletionProtection)
else str(deletion_protection)
)

if vector_type == "sparse" and dimension is not None:
if vector_type == VectorType.SPARSE.value and dimension is not None:
raise ValueError("dimension should not be specified for sparse indexes")

if deletion_protection in ["enabled", "disabled"]:
dp = DeletionProtection(deletion_protection)
dp = DeletionProtectionModel(deletion_protection)
else:
raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")

Expand All @@ -202,6 +210,7 @@ def create_index(

index_spec = self._parse_index_spec(spec)

api_instance = self.index_api
api_instance.create_index(
create_index_request=CreateIndexRequest(
**parse_non_empty_args(
Expand Down Expand Up @@ -301,17 +310,19 @@ def configure_index(
self,
name: str,
replicas: Optional[int] = None,
pod_type: Optional[str] = None,
deletion_protection: Optional[Literal["enabled", "disabled"]] = None,
pod_type: Optional[Union[PodType, str]] = None,
deletion_protection: Optional[Union[DeletionProtection, str]] = None,
tags: Optional[Dict[str, str]] = None,
):
api_instance = self.index_api
description = self.describe_index(name=name)

if deletion_protection is None:
dp = DeletionProtection(description.deletion_protection)
dp = DeletionProtectionModel(description.deletion_protection)
elif isinstance(deletion_protection, DeletionProtection):
dp = DeletionProtectionModel(deletion_protection.value)
elif deletion_protection in ["enabled", "disabled"]:
dp = DeletionProtection(deletion_protection)
dp = DeletionProtectionModel(deletion_protection)
else:
raise ValueError("deletion_protection must be either 'enabled' or 'disabled'")

Expand All @@ -330,7 +341,8 @@ def configure_index(

pod_config_args: Dict[str, Any] = {}
if pod_type:
pod_config_args.update(pod_type=pod_type)
new_pod_type = pod_type.value if isinstance(pod_type, PodType) else pod_type
pod_config_args.update(pod_type=new_pod_type)
if replicas:
pod_config_args.update(replicas=replicas)

Expand Down
18 changes: 18 additions & 0 deletions pinecone/enums/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .clouds import CloudProvider, AwsRegion, GcpRegion, AzureRegion
from .deletion_protection import DeletionProtection
from .metric import Metric
from .pod_index_environment import PodIndexEnvironment
from .pod_type import PodType
from .vector_type import VectorType

__all__ = [
"CloudProvider",
"AwsRegion",
"GcpRegion",
"AzureRegion",
"DeletionProtection",
"Metric",
"PodIndexEnvironment",
"PodType",
"VectorType",
]
22 changes: 22 additions & 0 deletions pinecone/enums/clouds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from enum import Enum


class CloudProvider(Enum):
AWS = "aws"
GCP = "gcp"
AZURE = "azure"


class AwsRegion(Enum):
US_EAST_1 = "us-east-1"
US_WEST_2 = "us-west-2"
EU_WEST_1 = "eu-west-1"


class GcpRegion(Enum):
US_CENTRAL1 = "us-central1"
EUROPE_WEST4 = "europe-west4"


class AzureRegion(Enum):
EAST_US = "eastus2"
6 changes: 6 additions & 0 deletions pinecone/enums/deletion_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import Enum


class DeletionProtection(Enum):
ENABLED = "enabled"
DISABLED = "disabled"
11 changes: 11 additions & 0 deletions pinecone/enums/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from enum import Enum


class Metric(Enum):
"""
The metric specifies how Pinecone should calculate the distance between vectors when querying an index.
"""

COSINE = "cosine"
EUCLIDEAN = "euclidean"
DOTPRODUCT = "dotproduct"
Original file line number Diff line number Diff line change
@@ -1,28 +1,11 @@
from enum import Enum


class CloudProvider(Enum):
AWS = "aws"
GCP = "gcp"
AZURE = "azure"


class AwsRegion(Enum):
US_EAST_1 = "us-east-1"
US_WEST_2 = "us-west-2"
EU_WEST_1 = "eu-west-1"


class GcpRegion(Enum):
US_CENTRAL1 = "us-central1"
EUROPE_WEST4 = "europe-west4"


class AzureRegion(Enum):
EAST_US = "eastus2"


class PodIndexEnvironment(Enum):
"""
These environment strings are used to specify where a pod index should be deployed.
"""

US_WEST1_GCP = "us-west1-gcp"
US_CENTRAL1_GCP = "us-central1-gcp"
US_WEST4_GCP = "us-west4-gcp"
Expand Down
20 changes: 20 additions & 0 deletions pinecone/enums/pod_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from enum import Enum


class PodType(Enum):
"""
PodType represents the available pod types for a pod index.
"""

P1_X1 = "p1.x1"
P1_X2 = "p1.x2"
P1_X4 = "p1.x4"
P1_X8 = "p1.x8"
S1_X1 = "s1.x1"
S1_X2 = "s1.x2"
S1_X4 = "s1.x4"
S1_X8 = "s1.x8"
P2_X1 = "p2.x1"
P2_X2 = "p2.x2"
P2_X4 = "p2.x4"
P2_X8 = "p2.x8"
14 changes: 14 additions & 0 deletions pinecone/enums/vector_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from enum import Enum


class VectorType(Enum):
"""
VectorType is used to specifiy the type of vector you will store in the index.
Dense vectors are used to store dense embeddings, which are vectors with non-zero values in most of the dimensions.
Sparse vectors are used to store sparse embeddings, which allow vectors with zero values in most of the dimensions to be represented concisely.
"""

DENSE = "dense"
SPARSE = "sparse"
9 changes: 2 additions & 7 deletions pinecone/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .index_description import ServerlessSpecDefinition, PodSpecDefinition
from .collection_description import CollectionDescription
from .serverless_spec import ServerlessSpec
from .pod_spec import PodSpec
from .pod_spec import PodSpec, PodType
from .index_list import IndexList
from .collection_list import CollectionList
from .index_model import IndexModel
from .clouds import CloudProvider, AwsRegion, GcpRegion, AzureRegion, PodIndexEnvironment
from ..enums.metric import Metric

__all__ = [
"CollectionDescription",
Expand All @@ -16,9 +16,4 @@
"IndexList",
"CollectionList",
"IndexModel",
"CloudProvider",
"AwsRegion",
"GcpRegion",
"AzureRegion",
"PodIndexEnvironment",
]
62 changes: 38 additions & 24 deletions pinecone/models/pod_spec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import NamedTuple, Optional, Dict
from dataclasses import dataclass, field
from typing import Optional, Dict, Union

from ..enums import PodIndexEnvironment, PodType

class PodSpec(NamedTuple):

@dataclass(frozen=True)
class PodSpec:
"""
PodSpec represents the configuration used to deploy a pod-based index.
Expand Down Expand Up @@ -33,41 +37,51 @@ class PodSpec(NamedTuple):
This value combines pod type and pod size into a single string. This configuration is your main lever for vertical scaling.
"""

metadata_config: Optional[Dict] = {}
metadata_config: Optional[Dict] = field(default_factory=dict)
"""
If you are storing a lot of metadata, you can use this configuration to limit the fields which are indexed for search.
If you are storing a lot of metadata, you can use this configuration to limit the fields which are indexed for search.
This configuration should be a dictionary with the key 'indexed' and the value as a list of fields to index.
For example, if your vectors have metadata along like this:
```python
from pinecone import Vector
vector = Vector(
id='237438191',
values=[...],
metadata={
'productId': '237438191',
'description': 'Stainless Steel Tumbler with Straw',
'category': 'kitchen',
'price': '19.99'
}
)
```
You might want to limit which fields are indexed with metadata config such as this:
Example:
```
{'indexed': ['field1', 'field2']}
```
"""

source_collection: Optional[str] = None
"""
The name of the collection to use as the source for the pod index. This configuration is only used when creating a pod index from an existing collection.
"""

def asdict(self):
def __init__(
self,
environment: Union[PodIndexEnvironment, str],
pod_type: Union[PodType, str] = "p1.x1",
replicas: Optional[int] = None,
shards: Optional[int] = None,
pods: Optional[int] = None,
metadata_config: Optional[Dict] = None,
source_collection: Optional[str] = None,
):
object.__setattr__(
self,
"environment",
environment.value if isinstance(environment, PodIndexEnvironment) else str(environment),
)
object.__setattr__(
self, "pod_type", pod_type.value if isinstance(pod_type, PodType) else str(pod_type)
)
object.__setattr__(self, "replicas", replicas)
object.__setattr__(self, "shards", shards)
object.__setattr__(self, "pods", pods)
object.__setattr__(
self, "metadata_config", metadata_config if metadata_config is not None else {}
)
object.__setattr__(self, "source_collection", source_collection)

def asdict(self) -> Dict:
"""
Returns the PodSpec as a dictionary.
"""
return {"pod": self._asdict()}
return {"pod": self.__dict__}
2 changes: 1 addition & 1 deletion pinecone/models/serverless_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Union
from enum import Enum

from .clouds import CloudProvider, AwsRegion, GcpRegion, AzureRegion
from ..enums import CloudProvider, AwsRegion, GcpRegion, AzureRegion


@dataclass(frozen=True)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import pytest
from pinecone import DeletionProtection


class TestDeletionProtection:
def test_deletion_protection(self, client, create_sl_index_params):
@pytest.mark.parametrize(
"dp_enabled,dp_disabled",
[("enabled", "disabled"), (DeletionProtection.ENABLED, DeletionProtection.DISABLED)],
)
def test_deletion_protection(self, client, create_sl_index_params, dp_enabled, dp_disabled):
name = create_sl_index_params["name"]
client.create_index(**create_sl_index_params, deletion_protection="enabled")
client.create_index(**create_sl_index_params, deletion_protection=dp_enabled)
desc = client.describe_index(name)
assert desc.deletion_protection == "enabled"

with pytest.raises(Exception) as e:
client.delete_index(name)
assert "Deletion protection is enabled for this index" in str(e.value)

client.configure_index(name, deletion_protection="disabled")
client.configure_index(name, deletion_protection=dp_disabled)
desc = client.describe_index(name)
assert desc.deletion_protection == "disabled"

Expand Down
Loading

0 comments on commit f2b1de0

Please sign in to comment.