Skip to content

Commit 556ce5f

Browse files
authored
Merge pull request #982 from CitrineInformatics/feature/pne-6367-feature-effects
[PNE-6367] Add support for feature effects.
2 parents 455e12e + 623a307 commit 556ce5f

File tree

6 files changed

+153
-4
lines changed

6 files changed

+153
-4
lines changed

src/citrine/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.11.6"
1+
__version__ = "3.12.0"
+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from typing import Dict
2+
from uuid import UUID
3+
4+
from citrine._rest.resource import Resource
5+
from citrine._serialization import properties
6+
7+
8+
class ShapleyMaterial(Resource):
9+
"""The feature effect of a material."""
10+
11+
material_id = properties.UUID('material_id', serializable=False)
12+
value = properties.Float('value', serializable=False)
13+
14+
15+
class ShapleyFeature(Resource):
16+
"""All feature effects for this feature by material."""
17+
18+
feature = properties.String('feature', serializable=False)
19+
materials = properties.List(properties.Object(ShapleyMaterial), 'materials',
20+
serializable=False)
21+
22+
@property
23+
def material_dict(self) -> Dict[UUID, float]:
24+
"""Presents the feature's effects as a dictionary by material."""
25+
return {material.material_id: material.value for material in self.materials}
26+
27+
28+
class ShapleyOutput(Resource):
29+
"""All feature effects for this output by feature."""
30+
31+
output = properties.String('output', serializable=False)
32+
features = properties.List(properties.Object(ShapleyFeature), 'features', serializable=False)
33+
34+
@property
35+
def feature_dict(self) -> Dict[str, Dict[UUID, float]]:
36+
"""Presents the output's feature effects as a dictionary by feature."""
37+
return {feature.feature: feature.material_dict for feature in self.features}
38+
39+
40+
class FeatureEffects(Resource):
41+
"""Captures information about the feature effects associated with a predictor."""
42+
43+
predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
44+
predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
45+
status = properties.String('metadata.status', serializable=False)
46+
failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason',
47+
serializable=False)
48+
49+
outputs = properties.Optional(properties.List(properties.Object(ShapleyOutput)), 'resultobj',
50+
serializable=False)
51+
52+
@classmethod
53+
def _pre_build(cls, data: dict) -> Dict:
54+
shapley = data["result"]
55+
material_ids = shapley["materials"]
56+
57+
outputs = []
58+
for output, feature_dict in shapley["outputs"].items():
59+
features = []
60+
for feature, values in feature_dict.items():
61+
items = zip(material_ids, values)
62+
materials = [{"material_id": mid, "value": value} for mid, value in items]
63+
features.append({
64+
"feature": feature,
65+
"materials": materials
66+
})
67+
68+
outputs.append({"output": output, "features": features})
69+
70+
data["resultobj"] = outputs
71+
return data
72+
73+
@property
74+
def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]:
75+
"""Presents the feature effects as a dictionary by output."""
76+
return {output.output: output.feature_dict for output in self.outputs}

src/citrine/informatics/predictors/graph_predictor.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from citrine._session import Session
88
from citrine._utils.functions import format_escaped_url
99
from citrine.informatics.data_sources import DataSource
10+
from citrine.informatics.feature_effects import FeatureEffects
1011
from citrine.informatics.predictors.single_predict_request import SinglePredictRequest
1112
from citrine.informatics.predictors.single_prediction import SinglePrediction
1213
from citrine.informatics.predictors import PredictorNode, Predictor
14+
from citrine.informatics.reports import Report
1315
from citrine.resources.report import ReportResource
1416

1517
__all__ = ['GraphPredictor']
@@ -104,7 +106,7 @@ def wrap_instance(predictor_data: dict) -> dict:
104106
}
105107

106108
@property
107-
def report(self):
109+
def report(self) -> Report:
108110
"""Fetch the predictor report."""
109111
if self.uid is None or self._session is None or self._project_id is None \
110112
or getattr(self, "version", None) is None:
@@ -113,6 +115,13 @@ def report(self):
113115
report_resource = ReportResource(self._project_id, self._session)
114116
return report_resource.get(predictor_id=self.uid, predictor_version=self.version)
115117

118+
@property
119+
def feature_effects(self) -> FeatureEffects:
120+
"""Retrieve the feature effects for all outputs in the predictor's training data.."""
121+
path = self._path() + '/shapley/query'
122+
response = self._session.post_resource(path, {}, version=self._api_version)
123+
return FeatureEffects.build(response)
124+
116125
def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction:
117126
"""Make a one-off prediction with this predictor."""
118127
path = self._path() + '/predict'

src/citrine/resources/table_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class TableConfig(Resource["TableConfig"]):
8888
The query used to define the materials underpinning this table
8989
generation_algorithm: TableFromGemdQueryAlgorithm
9090
Which algorithm was used to generate the config based on the GemdQuery results
91+
9192
"""
9293

9394
# FIXME (DML): rename this (this is dependent on the server side)

tests/informatics/test_predictors.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Tests for citrine.informatics.predictors."""
2-
import uuid
3-
import pytest
42
import mock
3+
import pytest
4+
import uuid
5+
from random import random
56

67
from citrine.informatics.data_sources import GemTableDataSource
78
from citrine.informatics.descriptors import RealDescriptor, IntegerDescriptor, \
@@ -12,6 +13,10 @@
1213
from citrine.informatics.predictors.single_prediction import SinglePrediction
1314
from citrine.informatics.design_candidate import DesignMaterial
1415

16+
from tests.utils.factories import FeatureEffectsResponseFactory
17+
from tests.utils.session import FakeCall, FakeSession
18+
19+
1520
w = IntegerDescriptor("w", lower_bound=0, upper_bound=100)
1621
x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
1722
y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="")
@@ -485,3 +490,21 @@ def test_single_predict(graph_predictor):
485490
prediction_out = graph_predictor.predict(request)
486491
assert prediction_out.dump() == prediction_in.dump()
487492
assert session.post_resource.call_count == 1
493+
494+
495+
def test_feature_effects(graph_predictor):
496+
feature_effects_response = FeatureEffectsResponseFactory()
497+
feature_effects_as_dict = feature_effects_response.pop("_result_as_dict")
498+
499+
session = FakeSession()
500+
session.set_response(feature_effects_response)
501+
502+
graph_predictor._session = session
503+
graph_predictor._project_id = uuid.uuid4()
504+
505+
fe = graph_predictor.feature_effects
506+
507+
expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \
508+
f"/versions/{graph_predictor.version}/shapley/query"
509+
assert session.calls == [FakeCall(method='POST', path=expected_path, json={})]
510+
assert fe.as_dict == feature_effects_as_dict

tests/utils/factories.py

+40
Original file line numberDiff line numberDiff line change
@@ -859,3 +859,43 @@ class AnalysisWorkflowEntityDataFactory(factory.DictFactory):
859859
id = factory.Faker('uuid4')
860860
data = factory.SubFactory(AnalysisWorkflowDataDataFactory)
861861
metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory)
862+
863+
864+
class FeatureEffectsResponseResultFactory(factory.DictFactory):
865+
materials = factory.List([
866+
factory.Faker('uuid4', cast_to=None),
867+
factory.Faker('uuid4', cast_to=None),
868+
factory.Faker('uuid4', cast_to=None)
869+
])
870+
outputs = factory.Dict({
871+
"output1": factory.Dict({
872+
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
873+
}),
874+
"output2": factory.Dict({
875+
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]),
876+
"feature2": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
877+
})
878+
})
879+
880+
class FeatureEffectsMetadataFactory(factory.DictFactory):
881+
predictor_id = factory.Faker('uuid4')
882+
predictor_version = factory.Faker('random_digit_not_null')
883+
created = factory.SubFactory(UserTimestampDataFactory)
884+
updated = factory.SubFactory(UserTimestampDataFactory)
885+
status = 'SUCCEEDED'
886+
887+
888+
class FeatureEffectsResponseFactory(factory.DictFactory):
889+
query = {} # Presently, querying from the SDK is not allowed.
890+
metadata = factory.SubFactory(FeatureEffectsMetadataFactory)
891+
result = factory.SubFactory(FeatureEffectsResponseResultFactory)
892+
_result_as_dict = factory.LazyAttribute(lambda obj: _expand_condensed(obj.result))
893+
894+
895+
def _expand_condensed(result_obj):
896+
whole_dict = {}
897+
for output, feature_dict in result_obj["outputs"].items():
898+
whole_dict[output] = {}
899+
for feature, values in feature_dict.items():
900+
whole_dict[output][feature] = dict(zip(result_obj["materials"], values))
901+
return whole_dict

0 commit comments

Comments
 (0)