Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PNE-6511] Expose feature effects generation. #983

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.12.0"
__version__ = "3.13.0"
10 changes: 8 additions & 2 deletions src/citrine/informatics/feature_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ class FeatureEffects(Resource):

@classmethod
def _pre_build(cls, data: dict) -> Dict:
shapley = data["result"]
shapley = data.get("result")
if not shapley:
return data
Comment on lines +54 to +56
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There was a bug in the existing code that caused it to blow up when there was no result field, which happens if either Shapley hasn't been run or the generation is still in progress.


material_ids = shapley["materials"]

outputs = []
Expand All @@ -73,4 +76,7 @@ def _pre_build(cls, data: dict) -> Dict:
@property
def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]:
"""Presents the feature effects as a dictionary by output."""
return {output.output: output.feature_dict for output in self.outputs}
if self.outputs:
return {output.output: output.feature_dict for output in self.outputs}
else:
return {}
28 changes: 28 additions & 0 deletions src/citrine/resources/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ def rename(self,
entity = self.session.put_resource(path, json, version=self._api_version)
return self.build(entity)

def generate_feature_effects(self,
                             uid: Union[UUID, str],
                             *,
                             version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor:
    """Trigger feature effects generation and return the re-fetched predictor.

    A PUT with an empty JSON payload is issued against the "shapley/generate"
    path of the given predictor version. The response to that call is not
    used; the predictor version is retrieved again via ``get`` and that
    object is what the caller receives.

    Parameters
    ----------
    uid: Union[UUID, str]
        Identifier of the predictor.
    version: Union[int, str]
        Version of the predictor; defaults to MOST_RECENT_VER.

    Returns
    -------
    GraphPredictor
        The predictor as returned by ``get`` after the generate request.

    """
    generate_path = self._construct_path(uid, version, "shapley/generate")
    empty_payload = {}
    # The generate response itself is deliberately ignored; see the docstring.
    self.session.put_resource(generate_path, empty_payload, version=self._api_version)
    refreshed = self.get(uid, version=version)
    return refreshed
Comment on lines +204 to +205
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The FeatureEffects object returned by the generate call will always be in progress and have no results yet. Additionally, as currently constructed, it doesn't provide an easy way to get the updated status: you'd have to retrieve the corresponding predictor, then call .feature_effects on it. So instead, this does the get automatically.


def delete(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER):
"""Predictor versions cannot be deleted at this time."""
msg = "Predictor versions cannot be deleted. Use 'archive_version' instead."
Expand Down Expand Up @@ -578,6 +586,26 @@ def rename(self,
uid, version=version, name=name, description=description
)

def generate_feature_effects_async(self,
                                   uid: Union[UUID, str],
                                   *,
                                   version: Union[int, str]) -> GraphPredictor:
    """Start generating feature effects for a predictor version.

    ``version`` may be any existing numerical version, "latest", or
    "most_recent". The call fails if the predictor has not already been
    trained.

    As of the end of 2024, feature effects are generated automatically for
    every new predictor once training succeeds. Use this method to
    regenerate those values, or to generate them for older predictors.

    Only the start of the process is triggered here. Generation usually
    takes a few minutes but can take much longer, and the old values become
    inaccessible as soon as this call completes. To wait for generation to
    finish and retrieve the new values once they are ready, use
    GraphPredictor.feature_effects.
    """
    versions = self._versions_collection
    return versions.generate_feature_effects(uid, version=version)

def delete(self, uid: Union[UUID, str]):
"""Predictors cannot be deleted at this time."""
msg = "Predictors cannot be deleted. Use 'archive_version' or 'archive_root' instead."
Expand Down
18 changes: 18 additions & 0 deletions tests/informatics/test_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,21 @@ def test_feature_effects(graph_predictor):
f"/versions/{graph_predictor.version}/shapley/query"
assert session.calls == [FakeCall(method='POST', path=expected_path, json={})]
assert fe.as_dict == feature_effects_as_dict


def test_feature_effects_in_progress(graph_predictor):
    """An in-progress response with no result yields no outputs and an empty dict."""
    in_progress_response = FeatureEffectsResponseFactory(
        metadata__status="INPROGRESS",
        result=None,
    )

    fake_session = FakeSession()
    fake_session.set_response(in_progress_response)

    graph_predictor._session = fake_session
    graph_predictor._project_id = uuid.uuid4()

    effects = graph_predictor.feature_effects

    # Only a single query against the shapley endpoint is expected.
    query_path = (
        f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}"
        f"/versions/{graph_predictor.version}/shapley/query"
    )
    expected_calls = [FakeCall(method='POST', path=query_path, json={})]
    assert fake_session.calls == expected_calls
    assert effects.outputs is None
    assert effects.as_dict == {}
22 changes: 20 additions & 2 deletions tests/resources/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
FakeSession
)
from tests.utils.factories import (
AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory,
TableDataSourceDataFactory
AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory,
FeatureEffectsResponseFactory, TableDataSourceDataFactory
)


Expand Down Expand Up @@ -734,6 +734,7 @@ def test_rename_name_only(valid_graph_predictor_data):
expected_payload = {"name": new_name, "description": None}
assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)]


def test_rename_description_only(valid_graph_predictor_data):
pred_id = valid_graph_predictor_data["id"]
pred_version = valid_graph_predictor_data["metadata"]["version"]
Expand All @@ -751,3 +752,20 @@ def test_rename_description_only(valid_graph_predictor_data):
versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id)
expected_payload = {"name": None, "description": new_description}
assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)]


def test_generate_shapley(valid_graph_predictor_data):
    """Generating feature effects sends a PUT to shapley/generate, then a GET of the version."""
    predictor_uid = valid_graph_predictor_data["id"]
    predictor_version = valid_graph_predictor_data["metadata"]["version"]
    fake_session = FakeSession()
    collection = PredictorCollection(uuid.uuid4(), fake_session)

    # First response answers the generate PUT, second answers the follow-up GET.
    in_progress_response = FeatureEffectsResponseFactory(
        metadata__status="INPROGRESS",
        result=None,
    )
    fake_session.set_responses(in_progress_response, valid_graph_predictor_data)

    base_path = _PredictorVersionCollection._path_template.format(
        project_id=collection.project_id,
        uid=predictor_uid,
    )
    version_path = f"{base_path}/{predictor_version}"
    returned_predictor = collection.generate_feature_effects_async(
        predictor_uid, version=predictor_version
    )

    expected_calls = [
        FakeCall(method="PUT", path=f"{version_path}/shapley/generate", json={}),
        FakeCall(method="GET", path=version_path),
    ]
    assert fake_session.calls == expected_calls
3 changes: 3 additions & 0 deletions tests/utils/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,9 @@ class FeatureEffectsResponseFactory(factory.DictFactory):


def _expand_condensed(result_obj):
if not result_obj:
return None

whole_dict = {}
for output, feature_dict in result_obj["outputs"].items():
whole_dict[output] = {}
Expand Down
Loading