Skip to content

Commit 1d2b4cb

Browse files
committed
[PNE-6367] Add support for feature effects.
The payload comes in as a condensed format, so it's expanded in order to constructed nested lists of objects for clarity and ease of use. Additionally, the hierarchy of data is flipped to more closely match how it will be used by our customers. To that end, 'as_dict' is provided to ease importing it into a pandas DataFrame for whatever processing and analysis the customer desires.
1 parent 455e12e commit 1d2b4cb

File tree

6 files changed

+149
-3
lines changed

6 files changed

+149
-3
lines changed

src/citrine/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.11.6"
1+
__version__ = "3.12.0"
+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from uuid import UUID
2+
3+
from citrine._rest.resource import Resource
4+
from citrine._serialization import properties
5+
6+
7+
class ShapleyMaterial(Resource):
8+
"""The feature effect of a material."""
9+
10+
material_id = properties.UUID('material_id', serializable=False)
11+
value = properties.Float('value', serializable=False)
12+
13+
14+
class ShapleyFeature(Resource):
15+
"""All feature effects for this feature by material."""
16+
17+
feature = properties.String('feature', serializable=False)
18+
materials = properties.List(properties.Object(ShapleyMaterial), 'materials',
19+
serializable=False)
20+
21+
@property
22+
def material_dict(self) -> dict[UUID, float]:
23+
"""Presents the feature's effects as a dictionary by material."""
24+
return {material.material_id: material.value for material in self.materials}
25+
26+
27+
class ShapleyOutput(Resource):
28+
"""All feature effects for this output by feature."""
29+
30+
output = properties.String('output', serializable=False)
31+
features = properties.List(properties.Object(ShapleyFeature), 'features', serializable=False)
32+
33+
@property
34+
def feature_dict(self) -> dict[str, dict[UUID, float]]:
35+
"""Presents the output's feature effects as a dictionary by feature."""
36+
return {feature.feature: feature.material_dict for feature in self.features}
37+
38+
39+
class FeatureEffects(Resource):
40+
"""Captures information about the feature effects associated with a predictor."""
41+
42+
predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
43+
predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
44+
status = properties.String('metadata.status', serializable=False)
45+
failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason',
46+
serializable=False)
47+
48+
outputs = properties.List(properties.Object(ShapleyOutput), 'resultobj', serializable=False)
49+
50+
@classmethod
51+
def _pre_build(cls, data: dict) -> dict:
52+
shapley = data["result"]
53+
material_ids = shapley["materials"]
54+
55+
outputs = []
56+
for output, feature_dict in shapley["outputs"].items():
57+
features = []
58+
for feature, values in feature_dict.items():
59+
items = zip(material_ids, values)
60+
materials = [{"material_id": mid, "value": value} for mid, value in items]
61+
features.append({
62+
"feature": feature,
63+
"materials": materials
64+
})
65+
66+
outputs.append({"output": output, "features": features})
67+
68+
data["resultobj"] = outputs
69+
return data
70+
71+
@property
72+
def as_dict(self) -> dict[str, dict[str, dict[UUID, float]]]:
73+
"""Presents the feature effects as a dictionary by output."""
74+
return {output.output: output.feature_dict for output in self.outputs}

src/citrine/informatics/predictors/graph_predictor.py

+8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from citrine._session import Session
88
from citrine._utils.functions import format_escaped_url
99
from citrine.informatics.data_sources import DataSource
10+
from citrine.informatics.feature_effects import FeatureEffects
1011
from citrine.informatics.predictors.single_predict_request import SinglePredictRequest
1112
from citrine.informatics.predictors.single_prediction import SinglePrediction
1213
from citrine.informatics.predictors import PredictorNode, Predictor
@@ -113,6 +114,13 @@ def report(self):
113114
report_resource = ReportResource(self._project_id, self._session)
114115
return report_resource.get(predictor_id=self.uid, predictor_version=self.version)
115116

117+
@property
118+
def feature_effects(self):
119+
"""Retrieve the feature effects for all outputs in the predictor's training data.."""
120+
path = self._path() + '/shapley/query'
121+
response = self._session.post_resource(path, {}, version=self._api_version)
122+
return FeatureEffects.build(response)
123+
116124
def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction:
117125
"""Make a one-off prediction with this predictor."""
118126
path = self._path() + '/predict'

src/citrine/resources/table_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class TableConfig(Resource["TableConfig"]):
8888
The query used to define the materials underpinning this table
8989
generation_algorithm: TableFromGemdQueryAlgorithm
9090
Which algorithm was used to generate the config based on the GemdQuery results
91+
9192
"""
9293

9394
# FIXME (DML): rename this (this is dependent on the server side)

tests/informatics/test_predictors.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Tests for citrine.informatics.predictors."""
2-
import uuid
3-
import pytest
42
import mock
3+
import pytest
4+
import uuid
5+
from random import random
56

67
from citrine.informatics.data_sources import GemTableDataSource
78
from citrine.informatics.descriptors import RealDescriptor, IntegerDescriptor, \
@@ -12,6 +13,10 @@
1213
from citrine.informatics.predictors.single_prediction import SinglePrediction
1314
from citrine.informatics.design_candidate import DesignMaterial
1415

16+
from tests.utils.factories import FeatureEffectsResponseFactory
17+
from tests.utils.session import FakeCall, FakeSession
18+
19+
1520
w = IntegerDescriptor("w", lower_bound=0, upper_bound=100)
1621
x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
1722
y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="")
@@ -485,3 +490,21 @@ def test_single_predict(graph_predictor):
485490
prediction_out = graph_predictor.predict(request)
486491
assert prediction_out.dump() == prediction_in.dump()
487492
assert session.post_resource.call_count == 1
493+
494+
495+
def test_feature_effects(graph_predictor):
496+
feature_effects_response = FeatureEffectsResponseFactory()
497+
feature_effects_as_dict = feature_effects_response.pop("_result_as_dict")
498+
499+
session = FakeSession()
500+
session.set_response(feature_effects_response)
501+
502+
graph_predictor._session = session
503+
graph_predictor._project_id = uuid.uuid4()
504+
505+
fe = graph_predictor.feature_effects
506+
507+
expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \
508+
f"/versions/{graph_predictor.version}/shapley/query"
509+
assert session.calls == [FakeCall(method='POST', path=expected_path, json={})]
510+
assert fe.as_dict == feature_effects_as_dict

tests/utils/factories.py

+40
Original file line numberDiff line numberDiff line change
@@ -859,3 +859,43 @@ class AnalysisWorkflowEntityDataFactory(factory.DictFactory):
859859
id = factory.Faker('uuid4')
860860
data = factory.SubFactory(AnalysisWorkflowDataDataFactory)
861861
metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory)
862+
863+
864+
class FeatureEffectsResponseResultFactory(factory.DictFactory):
865+
materials = factory.List([
866+
factory.Faker('uuid4', cast_to=None),
867+
factory.Faker('uuid4', cast_to=None),
868+
factory.Faker('uuid4', cast_to=None)
869+
])
870+
outputs = factory.Dict({
871+
"output1": factory.Dict({
872+
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
873+
}),
874+
"output2": factory.Dict({
875+
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]),
876+
"feature2": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
877+
})
878+
})
879+
880+
class FeatureEffectsMetadataFactory(factory.DictFactory):
881+
predictor_id = factory.Faker('uuid4')
882+
predictor_version = factory.Faker('random_digit_not_null')
883+
created = factory.SubFactory(UserTimestampDataFactory)
884+
updated = factory.SubFactory(UserTimestampDataFactory)
885+
status = 'SUCCEEDED'
886+
887+
888+
class FeatureEffectsResponseFactory(factory.DictFactory):
889+
query = {} # Presently, querying from the SDK is not allowed.
890+
metadata = factory.SubFactory(FeatureEffectsMetadataFactory)
891+
result = factory.SubFactory(FeatureEffectsResponseResultFactory)
892+
_result_as_dict = factory.LazyAttribute(lambda obj: _expand_condensed(obj.result))
893+
894+
895+
def _expand_condensed(result_obj):
896+
whole_dict = {}
897+
for output, feature_dict in result_obj["outputs"].items():
898+
whole_dict[output] = {}
899+
for feature, values in feature_dict.items():
900+
whole_dict[output][feature] = dict(zip(result_obj["materials"], values))
901+
return whole_dict

0 commit comments

Comments
 (0)