Skip to content

Commit 7eac254

Browse files
authored
Merge pull request #985 from CitrineInformatics/revert-shapley-generate
Revert "Merge pull request #983 from CitrineInformatics/feature/pne-6…
2 parents ae8701f + 1e6aa04 commit 7eac254

File tree

3 files changed

+3
-49
lines changed

3 files changed

+3
-49
lines changed

src/citrine/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.14.0"
1+
__version__ = "3.15.0"

src/citrine/resources/predictor.py

-28
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,6 @@ def rename(self,
196196
entity = self.session.put_resource(path, json, version=self._api_version)
197197
return self.build(entity)
198198

199-
def generate_feature_effects(self,
200-
uid: Union[UUID, str],
201-
*,
202-
version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor:
203-
path = self._construct_path(uid, version, "shapley/generate")
204-
self.session.put_resource(path, {}, version=self._api_version)
205-
return self.get(uid, version=version)
206-
207199
def delete(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER):
208200
"""Predictor versions cannot be deleted at this time."""
209201
msg = "Predictor versions cannot be deleted. Use 'archive_version' instead."
@@ -586,26 +578,6 @@ def rename(self,
586578
uid, version=version, name=name, description=description
587579
)
588580

589-
def generate_feature_effects_async(self,
590-
uid: Union[UUID, str],
591-
*,
592-
version: Union[int, str]) -> GraphPredictor:
593-
"""Begin generation of feature effects.
594-
595-
version can be any numerical version (which exists), "latest", or "most_recent". Although
596-
note that this will fail if the predictor is not already trained.
597-
598-
Feature effects are automatically generated for all new predictors after a successful
599-
training as of the end of 2024. This call allows either regenerating those values, or
600-
generating them for older predictors.
601-
602-
This call just begins the process; generation usually takes a few minutes, but can take
603-
much longer. As soon as the call completes, the old values will be inaccessible. To wait
604-
for the generation to complete, and to retrieve the new values once they're ready, use
605-
GraphPredictor.feature_effects.
606-
"""
607-
return self._versions_collection.generate_feature_effects(uid, version=version)
608-
609581
def delete(self, uid: Union[UUID, str]):
610582
"""Predictors cannot be deleted at this time."""
611583
msg = "Predictors cannot be deleted. Use 'archive_version' or 'archive_root' instead."

tests/resources/test_predictor.py

+2-20
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
FakeSession
2222
)
2323
from tests.utils.factories import (
24-
AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory,
25-
FeatureEffectsResponseFactory, TableDataSourceDataFactory
24+
AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory,
25+
TableDataSourceDataFactory
2626
)
2727

2828

@@ -734,7 +734,6 @@ def test_rename_name_only(valid_graph_predictor_data):
734734
expected_payload = {"name": new_name, "description": None}
735735
assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)]
736736

737-
738737
def test_rename_description_only(valid_graph_predictor_data):
739738
pred_id = valid_graph_predictor_data["id"]
740739
pred_version = valid_graph_predictor_data["metadata"]["version"]
@@ -752,20 +751,3 @@ def test_rename_description_only(valid_graph_predictor_data):
752751
versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id)
753752
expected_payload = {"name": None, "description": new_description}
754753
assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)]
755-
756-
757-
def test_generate_shapley(valid_graph_predictor_data):
758-
pred_id = valid_graph_predictor_data["id"]
759-
pred_version = valid_graph_predictor_data["metadata"]["version"]
760-
session = FakeSession()
761-
pc = PredictorCollection(uuid.uuid4(), session)
762-
763-
fe_response = FeatureEffectsResponseFactory(metadata__status="INPROGRESS", result=None)
764-
session.set_responses(fe_response, valid_graph_predictor_data)
765-
766-
versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id)
767-
pred = pc.generate_feature_effects_async(pred_id, version=pred_version)
768-
assert session.calls == [
769-
FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/shapley/generate", json={}),
770-
FakeCall(method="GET", path=f"{versions_path}/{pred_version}")
771-
]

0 commit comments

Comments
 (0)