diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index d1a7f1e0d..62ee17b83 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.12.0" +__version__ = "3.13.0" diff --git a/src/citrine/informatics/feature_effects.py b/src/citrine/informatics/feature_effects.py index 947ba021d..d0cd04d6e 100644 --- a/src/citrine/informatics/feature_effects.py +++ b/src/citrine/informatics/feature_effects.py @@ -51,7 +51,10 @@ class FeatureEffects(Resource): @classmethod def _pre_build(cls, data: dict) -> Dict: - shapley = data["result"] + shapley = data.get("result") + if not shapley: + return data + material_ids = shapley["materials"] outputs = [] @@ -73,4 +76,7 @@ def _pre_build(cls, data: dict) -> Dict: @property def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]: """Presents the feature effects as a dictionary by output.""" - return {output.output: output.feature_dict for output in self.outputs} + if self.outputs: + return {output.output: output.feature_dict for output in self.outputs} + else: + return {} diff --git a/src/citrine/resources/predictor.py b/src/citrine/resources/predictor.py index 08b69aaf0..f101a7366 100644 --- a/src/citrine/resources/predictor.py +++ b/src/citrine/resources/predictor.py @@ -196,6 +196,14 @@ def rename(self, entity = self.session.put_resource(path, json, version=self._api_version) return self.build(entity) + def generate_feature_effects(self, + uid: Union[UUID, str], + *, + version: Union[int, str] = MOST_RECENT_VER) -> GraphPredictor: + path = self._construct_path(uid, version, "shapley/generate") + self.session.put_resource(path, {}, version=self._api_version) + return self.get(uid, version=version) + def delete(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER): """Predictor versions cannot be deleted at this time.""" msg = "Predictor versions cannot be deleted. Use 'archive_version' instead." @@ -578,6 +586,26 @@ def rename(self, uid, version=version, name=name, description=description ) + def generate_feature_effects_async(self, + uid: Union[UUID, str], + *, + version: Union[int, str]) -> GraphPredictor: + """Begin generation of feature effects. + + version can be any numerical version (which exists), "latest", or "most_recent". Although + note that this will fail if the predictor is not already trained. + + Feature effects are automatically generated for all new predictors after a successful + training as of the end of 2024. This call allows either regenerating those values, or + generating them for older predictors. + + This call just begins the process; generation usually takes a few minutes, but can take + much longer. As soon as the call completes, the old values will be inaccessible. To wait + for the generation to complete, and to retrieve the new values once they're ready, use + GraphPredictor.feature_effects. + """ + return self._versions_collection.generate_feature_effects(uid, version=version) + def delete(self, uid: Union[UUID, str]): """Predictors cannot be deleted at this time.""" msg = "Predictors cannot be deleted. Use 'archive_version' or 'archive_root' instead." diff --git a/tests/informatics/test_predictors.py b/tests/informatics/test_predictors.py index 570c96514..da98b2717 100644 --- a/tests/informatics/test_predictors.py +++ b/tests/informatics/test_predictors.py @@ -508,3 +508,21 @@ def test_feature_effects(graph_predictor): f"/versions/{graph_predictor.version}/shapley/query" assert session.calls == [FakeCall(method='POST', path=expected_path, json={})] assert fe.as_dict == feature_effects_as_dict + + +def test_feature_effects_in_progress(graph_predictor): + feature_effects_response = FeatureEffectsResponseFactory(metadata__status="INPROGRESS", result=None) + + session = FakeSession() + session.set_response(feature_effects_response) + + graph_predictor._session = session + graph_predictor._project_id = uuid.uuid4() + + fe = graph_predictor.feature_effects + + expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \ + f"/versions/{graph_predictor.version}/shapley/query" + assert session.calls == [FakeCall(method='POST', path=expected_path, json={})] + assert fe.outputs is None + assert fe.as_dict == {} diff --git a/tests/resources/test_predictor.py b/tests/resources/test_predictor.py index b877f1c62..e8fff5209 100644 --- a/tests/resources/test_predictor.py +++ b/tests/resources/test_predictor.py @@ -21,8 +21,8 @@ FakeSession ) from tests.utils.factories import ( - AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory, - TableDataSourceDataFactory + AsyncDefaultPredictorResponseFactory, AsyncDefaultPredictorResponseMetadataFactory, + FeatureEffectsResponseFactory, TableDataSourceDataFactory ) @@ -734,6 +734,7 @@ def test_rename_name_only(valid_graph_predictor_data): expected_payload = {"name": new_name, "description": None} assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)] + def test_rename_description_only(valid_graph_predictor_data): pred_id = valid_graph_predictor_data["id"] pred_version = valid_graph_predictor_data["metadata"]["version"] @@ -751,3 +752,20 @@ def test_rename_description_only(valid_graph_predictor_data): versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) expected_payload = {"name": None, "description": new_description} assert session.calls == [FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/rename", json=expected_payload)] + + +def test_generate_shapley(valid_graph_predictor_data): + pred_id = valid_graph_predictor_data["id"] + pred_version = valid_graph_predictor_data["metadata"]["version"] + session = FakeSession() + pc = PredictorCollection(uuid.uuid4(), session) + + fe_response = FeatureEffectsResponseFactory(metadata__status="INPROGRESS", result=None) + session.set_responses(fe_response, valid_graph_predictor_data) + + versions_path = _PredictorVersionCollection._path_template.format(project_id=pc.project_id, uid=pred_id) + pred = pc.generate_feature_effects_async(pred_id, version=pred_version) + assert session.calls == [ + FakeCall(method="PUT", path=f"{versions_path}/{pred_version}/shapley/generate", json={}), + FakeCall(method="GET", path=f"{versions_path}/{pred_version}") + ] diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 83cf1fee4..d6d5c989a 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -893,6 +893,9 @@ class FeatureEffectsResponseFactory(factory.DictFactory): def _expand_condensed(result_obj): + if not result_obj: + return None + whole_dict = {} for output, feature_dict in result_obj["outputs"].items(): whole_dict[output] = {}