Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PNE-416] Use v4 listing endpoints. #947

Merged
merged 2 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.2.15"
__version__ = "3.3.0"
18 changes: 12 additions & 6 deletions src/citrine/_rest/pageable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def _fetch_page(self,
per_page: Optional[int] = None,
json_body: Optional[dict] = None,
additional_params: Optional[dict] = None,
*,
version: Optional[str] = None
) -> Tuple[Iterable[dict], str]:
"""
Fetch visible elements. This does not handle pagination.
Expand Down Expand Up @@ -58,6 +60,9 @@ def _fetch_page(self,
}
additional_params: dict, optional
A dict that allows extra parameters to be added to the request parameters
version: str, optional
A string denoting which version of the underlying API endpoint will be called. Defaults
to the collection's API version.

Returns
-------
Expand All @@ -68,15 +73,16 @@ def _fetch_page(self,

"""
# To avoid setting defaults -> reduce mutation risk, and to make more extensible
path = self._get_path() if path is None else path
fetch_func = self.session.get_resource if fetch_func is None else fetch_func
json_body = {} if json_body is None else json_body
path = path or self._get_path()
fetch_func = fetch_func or self.session.get_resource
json_body = json_body or {}

module_type = getattr(self, '_module_type', None)
params = self._page_params(page, per_page, module_type)
params = self._page_params(page, per_page)
params.update(additional_params or {})

data = fetch_func(path, params=params, version=self._api_version, **json_body)
version = version or self._api_version

data = fetch_func(path, params=params, version=version, **json_body)

try:
next_uri = data.get('next', "")
Expand Down
27 changes: 26 additions & 1 deletion src/citrine/resources/design_space.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Resources that represent collections of design spaces."""
from typing import Optional, TypeVar, Union
from functools import partial
from typing import Iterable, Optional, TypeVar, Union
from uuid import UUID

from gemd.enumeration.base_enumeration import BaseEnumeration
Expand Down Expand Up @@ -127,6 +128,30 @@ def restore(self, uid: Union[UUID, str]) -> DesignSpace:
entity = self.session.put_resource(url, {}, version=self._api_version)
return self.build(entity)

def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None):
filters = {}
if archived is not None:
filters["archived"] = archived

fetcher = partial(self._fetch_page,
fetch_func=partial(self.session.get_resource, version="v4"),
additional_params=filters)
return self._paginator.paginate(page_fetcher=fetcher,
collection_builder=self._build_collection_elements,
per_page=per_page)

def list_all(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
"""List the most recent version of all design spaces."""
return self._list_base(per_page=per_page)

def list(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
"""List the most recent version of all non-archived design spaces."""
return self._list_base(per_page=per_page, archived=False)

def list_archived(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
"""List the most recent version of all archived predictors."""
return self._list_base(per_page=per_page, archived=True)

def create_default(self,
*,
predictor_id: UUID,
Expand Down
31 changes: 24 additions & 7 deletions src/citrine/resources/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,20 +367,37 @@ def restore(self, uid: Union[UUID, str]):
raise NotImplementedError("The restore() method is no longer supported. You most likely "
"want restore_root(), or possibly restore_version().")

def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None):
filters = {}
if archived is not None:
filters["archived"] = archived

fetcher = partial(self._fetch_page,
additional_params=filters,
version="v4")
return self._paginator.paginate(page_fetcher=fetcher,
collection_builder=self._build_collection_elements,
per_page=per_page)

def list_all(self, *, per_page: int = 20) -> Iterable[GraphPredictor]:
"""List the most recent version of all predictors."""
return self._list_base(per_page=per_page)

def list(self, *, per_page: int = 20) -> Iterable[GraphPredictor]:
"""List the most recent version of all non-archived predictors."""
return self._list_base(per_page=per_page, archived=False)

def list_archived(self, *, per_page: int = 20) -> Iterable[GraphPredictor]:
"""List the most recent version of all archived predictors."""
return self._list_base(per_page=per_page, archived=True)

def list_versions(self,
uid: Union[UUID, str] = None,
*,
per_page: int = 100) -> Iterable[GraphPredictor]:
"""List all non-archived versions of the given Predictor."""
return self._versions_collection.list(uid, per_page=per_page)

def list_archived(self, *, per_page: int = 20) -> Iterable[GraphPredictor]:
"""List archived Predictors."""
fetcher = partial(self._fetch_page, additional_params={"filter": "archived eq 'true'"})
return self._paginator.paginate(page_fetcher=fetcher,
collection_builder=self._build_collection_elements,
per_page=per_page)

def list_archived_versions(self,
uid: Union[UUID, str] = None,
*,
Expand Down
40 changes: 39 additions & 1 deletion tests/resources/test_design_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,45 @@ def test_list_design_spaces(valid_formulation_design_space_data, valid_enumerate

# Then
expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id),
params={'per_page': 20, 'page': 1})
params={'per_page': 20, 'page': 1, 'archived': False})
assert 1 == session.num_calls, session.calls
assert expected_call == session.calls[0]
assert len(design_spaces) == 2


def test_list_all_design_spaces(valid_formulation_design_space_data, valid_enumerated_design_space_data):
# Given
session = FakeSession()
collection = DesignSpaceCollection(uuid.uuid4(), session)
session.set_response({
'response': [valid_formulation_design_space_data, valid_enumerated_design_space_data]
})

# When
design_spaces = list(collection.list_all(per_page=25))

# Then
expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id),
params={'per_page': 25, 'page': 1})
assert 1 == session.num_calls, session.calls
assert expected_call == session.calls[0]
assert len(design_spaces) == 2


def test_list_archived_design_spaces(valid_formulation_design_space_data, valid_enumerated_design_space_data):
# Given
session = FakeSession()
collection = DesignSpaceCollection(uuid.uuid4(), session)
session.set_response({
'response': [valid_formulation_design_space_data, valid_enumerated_design_space_data]
})

# When
design_spaces = list(collection.list_archived(per_page=25))

# Then
expected_call = FakeCall(method='GET', path='/projects/{}/design-spaces'.format(collection.project_id),
params={'per_page': 25, 'page': 1, 'archived': True})
assert 1 == session.num_calls, session.calls
assert expected_call == session.calls[0]
assert len(design_spaces) == 2
Expand Down
64 changes: 45 additions & 19 deletions tests/resources/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,30 +278,70 @@ def test_train(valid_graph_predictor_data):
assert session.calls == expected_calls


def test_list_predictors(valid_graph_predictor_data, valid_graph_predictor_data_empty):
def test_list(valid_graph_predictor_data, valid_graph_predictor_data_empty):
# Given
session = FakeSession()
collection = PredictorCollection(uuid.uuid4(), session)
session.set_responses(
{
'response': [valid_graph_predictor_data, valid_graph_predictor_data_empty],
'next': ''
'page': 1,
'per_page': 25
},
basic_predictor_report_data,
basic_predictor_report_data
)

# When
predictors = list(collection.list(per_page=20))
predictors = list(collection.list(per_page=25))

# Then
expected_call = FakeCall(method='GET', path='/projects/{}/predictors'.format(collection.project_id),
params={'per_page': 20, 'page': 1})
expected_call = FakeCall(method='GET',
path='/projects/{}/predictors'.format(collection.project_id),
params={'per_page': 25, 'page': 1, 'archived': False})
assert 1 == session.num_calls, session.calls
assert expected_call == session.calls[0]
assert len(predictors) == 2


def test_list_all(valid_graph_predictor_data, valid_graph_predictor_data_empty):
# Given
session = FakeSession()
collection = PredictorCollection(uuid.uuid4(), session)
session.set_responses(
{'response': [valid_graph_predictor_data, valid_graph_predictor_data_empty]},
basic_predictor_report_data,
basic_predictor_report_data
)

# When
predictors = list(collection.list_all(per_page=25))

# Then
expected_call = FakeCall(method='GET',
path='/projects/{}/predictors'.format(collection.project_id),
params={'per_page': 25, 'page': 1})
assert 1 == session.num_calls, session.calls
assert expected_call == session.calls[0]
assert len(predictors) == 2


def test_list_archived(valid_graph_predictor_data):
# Given
session = FakeSession()
session.set_response({'response': [valid_graph_predictor_data]})
pc = PredictorCollection(uuid.uuid4(), session)

# When
list(pc.list_archived())

# Then
assert session.num_calls == 1
assert session.last_call == FakeCall(method='GET',
path=f"/projects/{pc.project_id}/predictors",
params={'per_page': 20, 'page': 1, 'archived': True})


def test_get(valid_graph_predictor_data):
# Given
session = FakeSession()
Expand Down Expand Up @@ -445,20 +485,6 @@ def test_returned_predictor(valid_graph_predictor_data):
assert isinstance(result.predictors[-1], AutoMLPredictor)


def test_predictor_list_archived(valid_graph_predictor_data):
# Given
session = FakeSession()
session.set_response({'response': [valid_graph_predictor_data]})
pc = PredictorCollection(uuid.uuid4(), session)

# When
list(pc.list_archived())

# Then
assert session.num_calls == 1
assert session.last_call == FakeCall(method='GET', path=f"/projects/{pc.project_id}/predictors", params={"filter": "archived eq 'true'", 'per_page': 20, 'page': 1})


def test_list_versions(valid_graph_predictor_data):
# Given
session = FakeSession()
Expand Down