Skip to content

Commit 8a16228

Browse files
committed
[PNE-6367] Add support for feature effects.
The payload comes in as a condensed format, so it's expanded in order to constructed nested lists of objects for clarity and ease of use. Additionally, the hierarchy of data is flipped to more closely match how it will be used by our customers. To that end, 'as_dict' is provided to ease importing it into a pandas DataFrame for whatever processing and analysis the customer desires.
1 parent 455e12e commit 8a16228

File tree

6 files changed

+152
-4
lines changed

6 files changed

+152
-4
lines changed

src/citrine/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.11.6"
1+
__version__ = "3.12.0"
+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import Dict
2+
from uuid import UUID
3+
4+
from citrine._rest.resource import Resource
5+
from citrine._serialization import properties
6+
7+
8+
class ShapleyMaterial(Resource):
9+
"""The feature effect of a material."""
10+
11+
material_id = properties.UUID('material_id', serializable=False)
12+
value = properties.Float('value', serializable=False)
13+
14+
15+
class ShapleyFeature(Resource):
16+
"""All feature effects for this feature by material."""
17+
18+
feature = properties.String('feature', serializable=False)
19+
materials = properties.List(properties.Object(ShapleyMaterial), 'materials',
20+
serializable=False)
21+
22+
@property
23+
def material_dict(self) -> Dict[UUID, float]:
24+
"""Presents the feature's effects as a dictionary by material."""
25+
return {material.material_id: material.value for material in self.materials}
26+
27+
28+
class ShapleyOutput(Resource):
29+
"""All feature effects for this output by feature."""
30+
31+
output = properties.String('output', serializable=False)
32+
features = properties.List(properties.Object(ShapleyFeature), 'features', serializable=False)
33+
34+
@property
35+
def feature_dict(self) -> Dict[str, Dict[UUID, float]]:
36+
"""Presents the output's feature effects as a dictionary by feature."""
37+
return {feature.feature: feature.material_dict for feature in self.features}
38+
39+
40+
class FeatureEffects(Resource):
41+
"""Captures information about the feature effects associated with a predictor."""
42+
43+
predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
44+
predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
45+
status = properties.String('metadata.status', serializable=False)
46+
failure_reason = properties.Optional(properties.String(), 'metadata.failure_reason',
47+
serializable=False)
48+
49+
outputs = properties.List(properties.Object(ShapleyOutput), 'resultobj', serializable=False)
50+
51+
@classmethod
52+
def _pre_build(cls, data: dict) -> Dict:
53+
shapley = data["result"]
54+
material_ids = shapley["materials"]
55+
56+
outputs = []
57+
for output, feature_dict in shapley["outputs"].items():
58+
features = []
59+
for feature, values in feature_dict.items():
60+
items = zip(material_ids, values)
61+
materials = [{"material_id": mid, "value": value} for mid, value in items]
62+
features.append({
63+
"feature": feature,
64+
"materials": materials
65+
})
66+
67+
outputs.append({"output": output, "features": features})
68+
69+
data["resultobj"] = outputs
70+
return data
71+
72+
@property
73+
def as_dict(self) -> Dict[str, Dict[str, Dict[UUID, float]]]:
74+
"""Presents the feature effects as a dictionary by output."""
75+
return {output.output: output.feature_dict for output in self.outputs}

src/citrine/informatics/predictors/graph_predictor.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from citrine._session import Session
88
from citrine._utils.functions import format_escaped_url
99
from citrine.informatics.data_sources import DataSource
10+
from citrine.informatics.feature_effects import FeatureEffects
1011
from citrine.informatics.predictors.single_predict_request import SinglePredictRequest
1112
from citrine.informatics.predictors.single_prediction import SinglePrediction
1213
from citrine.informatics.predictors import PredictorNode, Predictor
14+
from citrine.informatics.reports import Report
1315
from citrine.resources.report import ReportResource
1416

1517
__all__ = ['GraphPredictor']
@@ -104,7 +106,7 @@ def wrap_instance(predictor_data: dict) -> dict:
104106
}
105107

106108
@property
107-
def report(self):
109+
def report(self) -> Report:
108110
"""Fetch the predictor report."""
109111
if self.uid is None or self._session is None or self._project_id is None \
110112
or getattr(self, "version", None) is None:
@@ -113,6 +115,13 @@ def report(self):
113115
report_resource = ReportResource(self._project_id, self._session)
114116
return report_resource.get(predictor_id=self.uid, predictor_version=self.version)
115117

118+
@property
119+
def feature_effects(self) -> FeatureEffects:
120+
"""Retrieve the feature effects for all outputs in the predictor's training data.."""
121+
path = self._path() + '/shapley/query'
122+
response = self._session.post_resource(path, {}, version=self._api_version)
123+
return FeatureEffects.build(response)
124+
116125
def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction:
117126
"""Make a one-off prediction with this predictor."""
118127
path = self._path() + '/predict'

src/citrine/resources/table_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class TableConfig(Resource["TableConfig"]):
8888
The query used to define the materials underpinning this table
8989
generation_algorithm: TableFromGemdQueryAlgorithm
9090
Which algorithm was used to generate the config based on the GemdQuery results
91+
9192
"""
9293

9394
# FIXME (DML): rename this (this is dependent on the server side)

tests/informatics/test_predictors.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Tests for citrine.informatics.predictors."""
2-
import uuid
3-
import pytest
42
import mock
3+
import pytest
4+
import uuid
5+
from random import random
56

67
from citrine.informatics.data_sources import GemTableDataSource
78
from citrine.informatics.descriptors import RealDescriptor, IntegerDescriptor, \
@@ -12,6 +13,10 @@
1213
from citrine.informatics.predictors.single_prediction import SinglePrediction
1314
from citrine.informatics.design_candidate import DesignMaterial
1415

16+
from tests.utils.factories import FeatureEffectsResponseFactory
17+
from tests.utils.session import FakeCall, FakeSession
18+
19+
1520
w = IntegerDescriptor("w", lower_bound=0, upper_bound=100)
1621
x = RealDescriptor("x", lower_bound=0, upper_bound=100, units="")
1722
y = RealDescriptor("y", lower_bound=0, upper_bound=100, units="")
@@ -485,3 +490,21 @@ def test_single_predict(graph_predictor):
485490
prediction_out = graph_predictor.predict(request)
486491
assert prediction_out.dump() == prediction_in.dump()
487492
assert session.post_resource.call_count == 1
493+
494+
495+
def test_feature_effects(graph_predictor):
496+
feature_effects_response = FeatureEffectsResponseFactory()
497+
feature_effects_as_dict = feature_effects_response.pop("_result_as_dict")
498+
499+
session = FakeSession()
500+
session.set_response(feature_effects_response)
501+
502+
graph_predictor._session = session
503+
graph_predictor._project_id = uuid.uuid4()
504+
505+
fe = graph_predictor.feature_effects
506+
507+
expected_path = f"/projects/{graph_predictor._project_id}/predictors/{graph_predictor.uid}" + \
508+
f"/versions/{graph_predictor.version}/shapley/query"
509+
assert session.calls == [FakeCall(method='POST', path=expected_path, json={})]
510+
assert fe.as_dict == feature_effects_as_dict

tests/utils/factories.py

+40
Original file line numberDiff line numberDiff line change
@@ -859,3 +859,43 @@ class AnalysisWorkflowEntityDataFactory(factory.DictFactory):
859859
id = factory.Faker('uuid4')
860860
data = factory.SubFactory(AnalysisWorkflowDataDataFactory)
861861
metadata = factory.SubFactory(AnalysisWorkflowMetadataDataFactory)
862+
863+
864+
class FeatureEffectsResponseResultFactory(factory.DictFactory):
865+
materials = factory.List([
866+
factory.Faker('uuid4', cast_to=None),
867+
factory.Faker('uuid4', cast_to=None),
868+
factory.Faker('uuid4', cast_to=None)
869+
])
870+
outputs = factory.Dict({
871+
"output1": factory.Dict({
872+
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
873+
}),
874+
"output2": factory.Dict({
875+
"feature1": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")]),
876+
"feature2": factory.List([factory.Faker("pyfloat"), factory.Faker("pyfloat"), factory.Faker("pyfloat")])
877+
})
878+
})
879+
880+
class FeatureEffectsMetadataFactory(factory.DictFactory):
881+
predictor_id = factory.Faker('uuid4')
882+
predictor_version = factory.Faker('random_digit_not_null')
883+
created = factory.SubFactory(UserTimestampDataFactory)
884+
updated = factory.SubFactory(UserTimestampDataFactory)
885+
status = 'SUCCEEDED'
886+
887+
888+
class FeatureEffectsResponseFactory(factory.DictFactory):
889+
query = {} # Presently, querying from the SDK is not allowed.
890+
metadata = factory.SubFactory(FeatureEffectsMetadataFactory)
891+
result = factory.SubFactory(FeatureEffectsResponseResultFactory)
892+
_result_as_dict = factory.LazyAttribute(lambda obj: _expand_condensed(obj.result))
893+
894+
895+
def _expand_condensed(result_obj):
896+
whole_dict = {}
897+
for output, feature_dict in result_obj["outputs"].items():
898+
whole_dict[output] = {}
899+
for feature, values in feature_dict.items():
900+
whole_dict[output][feature] = dict(zip(result_obj["materials"], values))
901+
return whole_dict

0 commit comments

Comments
 (0)