Skip to content

Commit d4a2e39

Browse files
authored
Merge pull request #983 from CitrineInformatics/feature/pne-6511-shapley-generation
[PNE-6511] Expose feature effects generation.
2 parents 556ce5f + 75705fe commit d4a2e39

File tree

6 files changed

+78
-5
lines changed

6 files changed

+78
-5
lines changed

src/citrine/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.12.0"
1+
__version__ = "3.13.0"

src/citrine/informatics/feature_effects.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ class FeatureEffects(Resource):
5151

5252
@classmethod
5353
def _pre_build(cls, data: dict) -> Dict:
54-
shapley = data["result"]
54+
shapley = data.get("result")
55+
if not shapley:
56+
return data
57+
5558
material_ids = shapley["materials"]
5659

5760
outputs = []
@@ -73,4 +76,7 @@ def _pre_build(cls, data: dict) -> Dict:
7376
@property
7477
def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]:
7578
"""Presents the feature effects as a dictionary by output."""
76-
return {output.output: output.feature_dict for output in self.outputs}
79+
if self.outputs:
80+
return {output.output: output.feature_dict for output in self.outputs}
81+
else:
82+
return {}

src/citrine/resources/predictor.py

+28
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,14 @@ def rename(self,
196196
entity = self.session.put_resource(path, json, version=self._api_version)
197197
return self.build(entity)
198198

199+
def generate_feature_effects(self,
200+
uid: Union[UUID, str],
201+
*,
202+
version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor:
203+
path = self._construct_path(uid, version, "shapley/generate")
204+
self.session.put_resource(path, {}, version=self._api_version)
205+
return self.get(uid, version=version)
206+
199207
def delete(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER):
200208
"""Predictor versions cannot be deleted at this time."""
201209
msg = "Predictor versions cannot be deleted. Use 'archive_version' instead."
@@ -578,6 +586,26 @@ def rename(self,
578586
uid, version=version, name=name, description=description
579587
)
580588

589+
def generate_feature_effects_async(self,
590+
uid: Union[UUID, str],
591+
*,
592+
version: Union[int, str]) -> GraphPredictor:
593+
"""Begin generation of feature effects.
594+
595+
version can be any numerical version (which exists), "latest", or "most_recent". Note,
596+
however, that this will fail if the predictor is not already trained.
597+
598+
Feature effects are automatically generated for all new predictors after a successful
599+
training as of the end of 2024. This call allows either regenerating those values, or
600+
generating them for older predictors.
601+
602+
This call just begins the process; generation usually takes a few minutes, but can take
603+
much longer. As soon as the call completes, the old values will be inaccessible. To wait
604+
for the generation to complete, and to retrieve the new values once they're ready, use
605+
GraphPredictor.feature_effects.
606+
"""
607+
return self._versions_collection.generate_feature_effects(uid, version=version)
608+
581609
def delete(self, uid: Union[UUID, str]):
582610
"""Predictors cannot be deleted at this time."""
583611
msg = "Predictors cannot be deleted. Use 'archive_version' or 'archive_root' instead."

tests/informatics/test_predictors.py

+18
Original file line numberDiff line numberDiff line change
@@ -508,3 +508,21 @@ def test_feature_effects(graph_predictor):
508508
f"/versions/{graph_predictor.version}/shapley/query"
509509
assert session.calls == [FakeCall(method='POST', path=expected_path, json={})]
510510
assert fe.as_dict == feature_effects_as_dict
511+
512+
513+
def test_feature_effects_in_progress(graph_predictor):
514+
feature_effects_response = FeatureEffectsResponseFactory(metadata__status="INPROGRESS", result=None)
515+
516+
session = FakeSession()
517+
session.set_response(feature_effects_response)
518+
519+
graph_predictor._session = session
520+
graph_predictor._project_id = uuid.uuid4()
521+
522+
fe = graph_predictor.feature_effects
523+
524+
expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \
525+
f"/versions/{graph_predictor.version}/shapley/query"
526+
assert session.calls == [FakeCall(method='POST', path=expected_path, json={})]
527+
assert fe.outputs is None
528+
assert fe.as_dict == {}

tests/resources/test_predictor.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
FakeSession
2222
)
2323
from tests.utils.factories import (
24-
AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory,
25-
TableDataSourceDataFactory
24+
AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory,
25+
FeatureEffectsResponseFactory, TableDataSourceDataFactory
2626
)
2727

2828

@@ -734,6 +734,7 @@ def test_rename_name_only(valid_graph_predictor_data):
734734
expected_payload = {"name": new_name, "description": None}
735735
assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)]
736736

737+
737738
def test_rename_description_only(valid_graph_predictor_data):
738739
pred_id = valid_graph_predictor_data["id"]
739740
pred_version = valid_graph_predictor_data["metadata"]["version"]
@@ -751,3 +752,20 @@ def test_rename_description_only(valid_graph_predictor_data):
751752
versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id)
752753
expected_payload = {"name": None, "description": new_description}
753754
assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)]
755+
756+
757+
def test_generate_shapley(valid_graph_predictor_data):
758+
pred_id = valid_graph_predictor_data["id"]
759+
pred_version = valid_graph_predictor_data["metadata"]["version"]
760+
session = FakeSession()
761+
pc = PredictorCollection(uuid.uuid4(), session)
762+
763+
fe_response = FeatureEffectsResponseFactory(metadata__status="INPROGRESS", result=None)
764+
session.set_responses(fe_response, valid_graph_predictor_data)
765+
766+
versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id)
767+
pred = pc.generate_feature_effects_async(pred_id, version=pred_version)
768+
assert session.calls == [
769+
FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/shapley/generate", json={}),
770+
FakeCall(method="GET", path=f"{versions_path}/{pred_version}")
771+
]

tests/utils/factories.py

+3
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,9 @@ class FeatureEffectsResponseFactory(factory.DictFactory):
893893

894894

895895
def _expand_condensed(result_obj):
896+
if not result_obj:
897+
return None
898+
896899
whole_dict = {}
897900
for output, feature_dict in result_obj["outputs"].items():
898901
whole_dict[output] = {}

0 commit comments

Comments
 (0)