Skip to content

Commit

Permalink
Add deletion protection
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Jul 18, 2024
1 parent d31b072 commit 2c48823
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 81 deletions.
64 changes: 40 additions & 24 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import warnings
import logging
from typing import Optional, Dict, Any, Union, List, Tuple, cast, NamedTuple
from typing import Optional, Dict, Any, Union, List, Tuple, Literal

from .index_host_store import IndexHostStore

Expand Down Expand Up @@ -29,7 +29,8 @@
PodSpec as PodSpecModel,
PodSpecMetadataConfig,
)
from pinecone.models import ServerlessSpec, PodSpec, IndexList, CollectionList
from pinecone.core.openapi.shared import API_VERSION
from pinecone.models import ServerlessSpec, PodSpec, IndexModel, IndexList, CollectionList
from .langchain_import_warnings import _build_langchain_attribute_error_message

from pinecone.data import Index
Expand Down Expand Up @@ -230,6 +231,7 @@ def __init__(
config=self.config,
openapi_config=self.openapi_config,
pool_threads=pool_threads,
api_version=API_VERSION,
)

self.index_host_store = IndexHostStore()
Expand Down Expand Up @@ -259,7 +261,7 @@ def create_index(
spec: Union[Dict, ServerlessSpec, PodSpec],
metric: Optional[str] = "cosine",
timeout: Optional[int] = None,
deletion_protection: Optional[bool] = False,
deletion_protection: Optional[Literal["enabled", "disabled"]] = "disabled",
):
"""Creates a Pinecone index.
Expand All @@ -280,6 +282,7 @@ def create_index(
:type timeout: int, optional
:param timeout: Specify the number of seconds to wait until index gets ready. If None, wait indefinitely; if >=0, time out after this many seconds;
if -1, return immediately and do not wait. Default: None
:param deletion_protection: If enabled, the index cannot be deleted. If disabled, the index can be deleted. Default: "disabled"
### Creating a serverless index
Expand All @@ -293,7 +296,8 @@ def create_index(
name="my_index",
dimension=1536,
metric="cosine",
spec=ServerlessSpec(cloud="aws", region="us-west-2")
spec=ServerlessSpec(cloud="aws", region="us-west-2"),
deletion_protection="enabled"
)
```
Expand All @@ -312,7 +316,8 @@ def create_index(
spec=PodSpec(
environment="us-east1-gcp",
pod_type="p1.x1"
)
),
deletion_protection="enabled"
)
```
"""
Expand All @@ -322,10 +327,7 @@ def create_index(
def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
return {arg_name: val for arg_name, val in args if val is not None}

if deletion_protection:
dp = DeletionProtection("enabled")
else:
dp = DeletionProtection("disabled")
dp = DeletionProtection(deletion_protection)

if isinstance(spec, dict):
if "serverless" in spec:
Expand All @@ -345,16 +347,15 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
args_dict["metadata_config"] = PodSpecMetadataConfig(
indexed=args_dict["metadata_config"].get("indexed", None)
)
index_spec = IndexSpec(pod=PodSpecModel(**args_dict), deletion_protection=dp)
index_spec = IndexSpec(pod=PodSpecModel(**args_dict))
else:
raise ValueError("spec must contain either 'serverless' or 'pod' key")
elif isinstance(spec, ServerlessSpec):
index_spec = IndexSpec(
serverless=ServerlessSpecModel(
cloud=spec.cloud,
region=spec.region,
),
deletion_protection=dp
)
)
elif isinstance(spec, PodSpec):
args_dict = _parse_non_empty_args(
Expand All @@ -373,8 +374,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
environment=spec.environment,
pod_type=spec.pod_type,
**args_dict,
),
deletion_protection=dp
)
)
else:
raise TypeError("spec must be of type dict, ServerlessSpec, or PodSpec")
Expand All @@ -385,6 +385,7 @@ def _parse_non_empty_args(args: List[Tuple[str, Any]]) -> Dict[str, Any]:
dimension=dimension,
metric=metric,
spec=index_spec,
deletion_protection=dp,
),
)

Expand Down Expand Up @@ -463,7 +464,7 @@ def list_indexes(self) -> IndexList:
index name, dimension, metric, status, and spec.
:return: Returns an `IndexList` object, which is iterable and contains a
list of `IndexDescription` objects. It also has a convenience method `names()`
list of `IndexModel` objects. It also has a convenience method `names()`
which returns a list of index names.
```python
Expand Down Expand Up @@ -507,7 +508,7 @@ def describe_index(self, name: str):
"""Describes a Pinecone index.
:param name: the name of the index to describe.
:return: Returns an `IndexDescription` object
:return: Returns an `IndexModel` object
which gives access to properties such as the
index name, dimension, metric, host url, status,
and spec.
Expand Down Expand Up @@ -539,13 +540,16 @@ def describe_index(self, name: str):
host = description.host
self.index_host_store.set_host(self.config, name, host)

return description
return IndexModel(
description
)

def configure_index(
self,
name: str,
replicas: Optional[int] = None,
pod_type: Optional[str] = None,
deletion_protection: Optional[Literal["enabled", "disabled"]] = None,
):
"""This method is used to scale configuration fields for your pod-based Pinecone index.
Expand All @@ -570,15 +574,27 @@ def configure_index(
"""
api_instance = self.index_api
config_args: Dict[str, Any] = {}

description = self.describe_index(name=name)

if deletion_protection is None:
dp = DeletionProtection(description.deletion_protection)
else:
dp = DeletionProtection(deletion_protection)

pod_config_args: Dict[str, Any] = {}
if pod_type:
config_args.update(pod_type=pod_type)
pod_config_args.update(pod_type=pod_type)
if replicas:
config_args.update(replicas=replicas)
configure_index_request = ConfigureIndexRequest(
spec=ConfigureIndexRequestSpec(pod=ConfigureIndexRequestSpecPod(**config_args))
)
api_instance.configure_index(name, configure_index_request=configure_index_request)
pod_config_args.update(replicas=replicas)

if pod_config_args != {}:
spec = ConfigureIndexRequestSpec(pod=ConfigureIndexRequestSpecPod(**pod_config_args))
req = ConfigureIndexRequest(deletion_protection=dp, spec=spec)
else:
req = ConfigureIndexRequest(deletion_protection=dp)

api_instance.configure_index(name, configure_index_request=req)

def create_collection(self, name: str, source: str):
"""Create a collection from a pod-based index
Expand Down
6 changes: 2 additions & 4 deletions pinecone/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from .index_description import (
IndexDescription,
IndexStatus,
ServerlessSpecDefinition,
PodSpecDefinition,
)
Expand All @@ -9,15 +7,15 @@
from .pod_spec import PodSpec
from .index_list import IndexList
from .collection_list import CollectionList
from .index_model import IndexModel

__all__ = [
"CollectionDescription",
"IndexDescription",
"IndexStatus",
"PodSpec",
"PodSpecDefinition",
"ServerlessSpec",
"ServerlessSpecDefinition",
"IndexList",
"CollectionList",
"IndexModel"
]
54 changes: 3 additions & 51 deletions pinecone/models/index_description.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
from typing import NamedTuple, Dict, Optional, Union, Literal


class IndexStatus(NamedTuple):
state: str
ready: bool


PodKey = Literal["pod"]


class PodSpecDefinition(NamedTuple):
replicas: int
shards: int
Expand All @@ -17,51 +8,12 @@ class PodSpecDefinition(NamedTuple):
environment: str
metadata_config: Optional[Dict]


PodSpec = Dict[PodKey, PodSpecDefinition]

ServerlessKey = Literal["serverless"]


class ServerlessSpecDefinition(NamedTuple):
cloud: str
region: str

PodKey = Literal["pod"]
PodSpec = Dict[PodKey, PodSpecDefinition]

ServerlessKey = Literal["serverless"]
ServerlessSpec = Dict[ServerlessKey, ServerlessSpecDefinition]


class IndexDescription(NamedTuple):
"""
The description of an index. This object is returned from the `describe_index()` method.
"""

name: str
"""
The name of the index
"""

dimension: int
"""
The dimension of the index. This corresponds to the length of the vectors stored in the index.
"""

metric: str
"""
One of 'cosine', 'euclidean', or 'dotproduct'.
"""

host: str
"""
The endpoint you will use to connect to this index for data operations such as upsert and query.
"""

spec: Union[PodSpec, ServerlessSpec]
"""
The spec describes how the index is being deployed.
"""

status: IndexStatus
"""
Status includes information on whether the index is ready to accept data operations.
"""
15 changes: 15 additions & 0 deletions pinecone/models/index_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pinecone.core.openapi.control.models import IndexModel as OpenAPIIndexModel

class IndexModel:
def __init__(self, index: OpenAPIIndexModel):
self.index = index
self.deletion_protection = index.deletion_protection.value

def __str__(self):
return str(self.index)

def __repr__(self):
return repr(self.index)

def __getattr__(self, attr):
return getattr(self.index, attr)
40 changes: 40 additions & 0 deletions tests/integration/control/pod/test_deletion_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from pinecone import PodSpec

class TestDeletionProtection:
def test_deletion_protection(self, client, index_name, environment):
client.create_index(name=index_name, dimension=2, deletion_protection="enabled", spec=PodSpec(environment=environment))
desc = client.describe_index(index_name)
print(desc.deletion_protection)
print(desc.deletion_protection.__class__)
assert desc.deletion_protection == "enabled"

with pytest.raises(Exception) as e:
client.delete_index(index_name)
assert "Deletion protection is enabled for this index" in str(e.value)

client.configure_index(index_name, deletion_protection="disabled")
desc = client.describe_index(index_name)
assert desc.deletion_protection == "disabled"

client.delete_index(index_name)

def test_configure_index_with_deletion_protection(self, client, index_name, environment):
client.create_index(name=index_name, dimension=2, deletion_protection="enabled", spec=PodSpec(environment=environment))
desc = client.describe_index(index_name)
assert desc.deletion_protection == "enabled"

# Changing replicas only should not change deletion protection
client.configure_index(name=index_name, replicas=2)
desc = client.describe_index(index_name)
assert desc.spec.pod.replicas == 2
assert desc.deletion_protection == "enabled"

# Changing both replicas and delete protection in one shot
client.configure_index(name=index_name, replicas=3, deletion_protection="disabled")
desc = client.describe_index(index_name)
assert desc.spec.pod.replicas == 3
assert desc.deletion_protection == "disabled"

# Cleanup
client.delete_index(index_name)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def test_create_index(self, client, create_sl_index_params):
assert desc.name == name
assert desc.dimension == dimension
assert desc.metric == "cosine"
assert desc.deletion_protection is "disabled" # default value

@pytest.mark.parametrize("metric", ["cosine", "euclidean", "dotproduct"])
def test_create_index_with_metric(self, client, create_sl_index_params, metric):
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/control/serverless/test_deletion_protection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

class TestDeletionProtection:
def test_deletion_protection(self, client, create_sl_index_params):
name = create_sl_index_params["name"]
client.create_index(**create_sl_index_params, deletion_protection=True)
desc = client.describe_index(name)
assert desc.deletion_protection == "enabled"

with pytest.raises(Exception) as e:
client.delete_index(name)
assert "Deletion protection is enabled for this index" in str(e.value)

client.configure_index(name, deletion_protection=False)
desc = client.describe_index(name)
assert desc.deletion_protection == "disabled"

client.delete_index(name)
8 changes: 6 additions & 2 deletions tests/unit/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,19 @@ def test_passing_additional_headers(self):
for key, value in extras.items():
assert p.index_api.api_client.default_headers[key] == value
assert "User-Agent" in p.index_api.api_client.default_headers
assert len(p.index_api.api_client.default_headers) == 3
assert "X-Pinecone-API-Version" in p.index_api.api_client.default_headers
assert "header1" in p.index_api.api_client.default_headers
assert "header2" in p.index_api.api_client.default_headers
assert len(p.index_api.api_client.default_headers) == 4

def test_overwrite_useragent(self):
# This doesn't seem like a common use case, but we may want to allow this
# when embedding the client in other pinecone tools such as canopy.
extras = {"User-Agent": "test-user-agent"}
p = Pinecone(api_key="123-456-789", additional_headers=extras)
assert "X-Pinecone-API-Version" in p.index_api.api_client.default_headers
assert p.index_api.api_client.default_headers["User-Agent"] == "test-user-agent"
assert len(p.index_api.api_client.default_headers) == 1
assert len(p.index_api.api_client.default_headers) == 2

def test_set_source_tag_in_useragent(self):
p = Pinecone(api_key="123-456-789", source_tag="test_source_tag")
Expand Down

0 comments on commit 2c48823

Please sign in to comment.