From 75705fe47cd5b07ca7aac66febd175f7e099b4e8 Mon Sep 17 00:00:00 2001 From: Austin Noto-Moniz Date: Wed, 8 Jan 2025 10:35:08 -0500 Subject: [PATCH] [PNE-6511] Expose feature effects generation. At present, the backend will automatically generate feature effects upon successful training. However, that means old models lack them. Additionally, there may be circumstances where regenerating the values may be desired. The exposed endpoint will begin an asynchronous job which will generate new Shapley values, and replace the old ones. --- src/citrine/__version__.py | 2 +- src/citrine/informatics/feature_effects.py | 10 ++++++-- src/citrine/resources/predictor.py | 28 ++++++++++++++++++++++ tests/informatics/test_predictors.py | 18 ++++++++++++++ tests/resources/test_predictor.py | 22 +++++++++++++++-- tests/utils/factories.py | 3 +++ 6 files changed, 78 insertions(+), 5 deletions(-) diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index d1a7f1e0d..62ee17b83 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.12.0" +__version__ = "3.13.0" diff --git a/src/citrine/informatics/feature_effects.py b/src/citrine/informatics/feature_effects.py index 947ba021d..d0cd04d6e 100644 --- a/src/citrine/informatics/feature_effects.py +++ b/src/citrine/informatics/feature_effects.py @@ -51,7 +51,10 @@ class FeatureEffects(Resource): @classmethod def _pre_build(cls, data: dict) -> Dict: - shapley = data["result"] + shapley = data.get("result") + if not shapley: + return data + material_ids = shapley["materials"] outputs = [] @@ -73,4 +76,7 @@ def _pre_build(cls, data: dict) -> Dict: @property def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]: """Presents the feature effects as a dictionary by output.""" - return {output.output: output.feature_dict for output in self.outputs} + if self.outputs: + return {output.output: output.feature_dict for output in self.outputs} + else: + return {} diff --git a/src/citrine/resources/predictor.py b/src/citrine/resources/predictor.py index 08b69aaf0..f101a7366 100644 --- a/src/citrine/resources/predictor.py +++ b/src/citrine/resources/predictor.py @@ -196,6 +196,14 @@ def rename(self, entity = self.session.put_resource(path, json, version=self._api_version) return self.build(entity) + def generate_feature_effects(self, + uid: Union[UUID, str], + *, + version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor: + path = self._construct_path(uid, version, "shapley/generate") + self.session.put_resource(path, {}, version=self._api_version) + return self.get(uid, version=version) + def delete(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER): """Predictor versions cannot be deleted at this time.""" msg = "Predictor versions cannot be deleted. Use 'archive_version' instead." @@ -578,6 +586,26 @@ def rename(self, uid, version=version, name=name, description=description ) + def generate_feature_effects_async(self, + uid: Union[UUID, str], + *, + version: Union[int, str]) -> GraphPredictor: + """Begin generation of feature effects. + + version can be any numerical version (which exists), "latest", or "most_recent". Although + note that this will fail if the predictor is not already trained. + + Feature effects are automatically generated for all new predictors after a successful + training as of the end of 2024. This call allows either regenerating those values, or + generating them for older predictors. + + This call just begins the process; generation usually takes a few minutes, but can take + much longer. As soon as the call completes, the old values will be inaccessible. To wait + for the generation to complete, and to retrieve the new values once they're ready, use + GraphPredictor.feature_effects. + """ + return self._versions_collection.generate_feature_effects(uid, version=version) + def delete(self, uid: Union[UUID, str]): """Predictors cannot be deleted at this time.""" msg = "Predictors cannot be deleted. Use 'archive_version' or 'archive_root' instead." diff --git a/tests/informatics/test_predictors.py b/tests/informatics/test_predictors.py index 570c96514..da98b2717 100644 --- a/tests/informatics/test_predictors.py +++ b/tests/informatics/test_predictors.py @@ -508,3 +508,21 @@ def test_feature_effects(graph_predictor): f"/versions/{graph_predictor.version}/shapley/query" assert session.calls == [FakeCall(method='POST', path=expected_path, json={})] assert fe.as_dict == feature_effects_as_dict + + +def test_feature_effects_in_progress(graph_predictor): + feature_effects_response = FeatureEffectsResponseFactory(metadata__status="INPROGRESS", result=None) + + session = FakeSession() + session.set_response(feature_effects_response) + + graph_predictor._session = session + graph_predictor._project_id = uuid.uuid4() + + fe = graph_predictor.feature_effects + + expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \ + f"/versions/{graph_predictor.version}/shapley/query" + assert session.calls == [FakeCall(method='POST', path=expected_path, json={})] + assert fe.outputs is None + assert fe.as_dict == {} diff --git a/tests/resources/test_predictor.py b/tests/resources/test_predictor.py index b877f1c62..e8fff5209 100644 --- a/tests/resources/test_predictor.py +++ b/tests/resources/test_predictor.py @@ -21,8 +21,8 @@ FakeSession ) from tests.utils.factories import ( - AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory, - TableDataSourceDataFactory + AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory, + FeatureEffectsResponseFactory, TableDataSourceDataFactory ) @@ -734,6 +734,7 @@ def test_rename_name_only(valid_graph_predictor_data): expected_payload = {"name": new_name, "description": None} assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)] + def test_rename_description_only(valid_graph_predictor_data): pred_id = valid_graph_predictor_data["id"] pred_version = valid_graph_predictor_data["metadata"]["version"] @@ -751,3 +752,20 @@ def test_rename_description_only(valid_graph_predictor_data): versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) expected_payload = {"name": None, "description": new_description} assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)] + + +def test_generate_shapley(valid_graph_predictor_data): + pred_id = valid_graph_predictor_data["id"] + pred_version = valid_graph_predictor_data["metadata"]["version"] + session = FakeSession() + pc = PredictorCollection(uuid.uuid4(), session) + + fe_response = FeatureEffectsResponseFactory(metadata__status="INPROGRESS", result=None) + session.set_responses(fe_response, valid_graph_predictor_data) + + versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + pred = pc.generate_feature_effects_async(pred_id, version=pred_version) + assert session.calls == [ + FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/shapley/generate", json={}), + FakeCall(method="GET", path=f"{versions_path}/{pred_version}") + ] diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 83cf1fee4..d6d5c989a 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -893,6 +893,9 @@ class FeatureEffectsResponseFactory(factory.DictFactory): def _expand_condensed(result_obj): + if not result_obj: + return None + whole_dict = {} for output, feature_dict in result_obj["outputs"].items(): whole_dict[output] = {}