Skip to content

Commit 14025e5

Browse files
committed
Table and model update logic
1 parent e57e092 commit 14025e5

File tree

8 files changed

+222
-3
lines changed

8 files changed

+222
-3
lines changed

src/citrine/__version__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.11.0"
1+
__version__ = "3.11.1"

src/citrine/informatics/predictors/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# flake8: noqa
22
from .predictor import *
33
from .node import *
4+
from .attribute_accumulation_predictor import *
45
from .expression_predictor import *
56
from .graph_predictor import *
67
from .ingredient_fractions_predictor import *
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from typing import List
2+
3+
from citrine._rest.resource import Resource
4+
from citrine._serialization import properties as _properties
5+
from citrine.informatics.descriptors import Descriptor
6+
from citrine.informatics.predictors import PredictorNode
7+
8+
__all__ = ['AttributeAccumulationPredictor']
9+
10+
11+
class AttributeAccumulationPredictor(Resource["AttributeAccumulationPredictor"], PredictorNode):
12+
"""A predictor that computes an output from an expression and set of bounded inputs.
13+
14+
For a discussion of expression syntax and a list of allowed symbols,
15+
please see the :ref:`documentation<Attribute Accumulation>`.
16+
17+
Parameters
18+
----------
19+
name: str
20+
name of the configuration
21+
description: str
22+
the description of the predictor
23+
attributes: List[Descriptor]
24+
the attributes that are accumulated from ancestor nodes
25+
26+
"""
27+
28+
attributes = _properties.List(_properties.Object(Descriptor), 'attributes')
29+
sequential = _properties.Boolean('sequential')
30+
31+
typ = _properties.String('type', default='AttributeAccumulation', deserializable=False)
32+
33+
def __init__(self,
34+
name: str,
35+
*,
36+
description: str,
37+
attributes: List[Descriptor],
38+
sequential: bool):
39+
self.name: str = name
40+
self.description: str = description
41+
self.attributes: List[Descriptor] = attributes
42+
self.sequential: bool = sequential
43+
44+
def __str__(self):
45+
return '<AttributeAccumulationPredictor {!r}>'.format(self.name)

src/citrine/informatics/predictors/graph_predictor.py

+69
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,72 @@ def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction:
118118
path = self._path() + '/predict'
119119
res = self._session.post_resource(path, predict_request.dump(), version=self._api_version)
120120
return SinglePrediction.build(res)
121+
122+
def _convert_to_multistep(self) -> "GraphPredictor":
123+
"""Make the GraphPredictor look as if generated with a MULTISTEP_MATERIALS datasource."""
124+
from citrine.informatics.predictors import (
125+
AttributeAccumulationPredictor, MolecularStructureFeaturizer,
126+
LabelFractionsPredictor, SimpleMixturePredictor, IngredientFractionsPredictor,
127+
AutoMLPredictor, MeanPropertyPredictor, ChemicalFormulaFeaturizer
128+
)
129+
130+
automl_outputs = {}
131+
featurizer_outputs = set()
132+
automl_inputs = {}
133+
134+
for predictor in self.predictors:
135+
if isinstance(predictor, AttributeAccumulationPredictor):
136+
raise ValueError("Graph already contains Attribute Accumulation nodes")
137+
elif isinstance(predictor, AutoMLPredictor):
138+
for descriptor in predictor.outputs:
139+
automl_outputs[descriptor.key] = descriptor
140+
for descriptor in predictor.inputs:
141+
automl_inputs[descriptor.key] = descriptor
142+
elif isinstance(predictor, MeanPropertyPredictor):
143+
for descriptor in predictor.properties:
144+
featurizer_outputs.add(
145+
f"mean of property {descriptor.key} in {predictor.input_descriptor.key}"
146+
)
147+
elif isinstance(predictor, IngredientFractionsPredictor):
148+
for ingredient in predictor.ingredients:
149+
featurizer_outputs.add(
150+
f"{ingredient} share in {predictor.input_descriptor.key}"
151+
)
152+
elif isinstance(predictor, LabelFractionsPredictor):
153+
for label in predictor.labels:
154+
featurizer_outputs.add(
155+
f"{label} share in {predictor.input_descriptor.key}"
156+
)
157+
elif isinstance(predictor, (SimpleMixturePredictor, ChemicalFormulaFeaturizer,
158+
MolecularStructureFeaturizer)):
159+
pass
160+
else:
161+
# IngredientsToFormulationRelation, ExpressionPredictor,
162+
# IngredientsToFormulationPredictor
163+
raise NotImplementedError(f"Unhandled predictor type: {type(predictor)}")
164+
165+
output_accumulator = AttributeAccumulationPredictor(
166+
name="Output variable accumulation",
167+
description="Output variables encountered in the material history. "
168+
"Only sequential mixing steps are considered.",
169+
attributes=list(automl_outputs.values()),
170+
sequential=True
171+
)
172+
input_accumulator = AttributeAccumulationPredictor(
173+
name="Attribute accumulation",
174+
description="Parameters/conditions encountered in the material history. "
175+
"Most recent values are selected first.",
176+
attributes=[automl_inputs[key] for key in automl_inputs
177+
if key not in featurizer_outputs],
178+
sequential=False
179+
)
180+
181+
update = GraphPredictor(
182+
name=self.name,
183+
description=self.description,
184+
predictors=self.predictors + [output_accumulator, input_accumulator],
185+
training_data=self.training_data
186+
)
187+
update.uid = self.uid
188+
189+
return update

src/citrine/informatics/predictors/node.py

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class PredictorNode(PolymorphicSerializable["PredictorNode"], Predictor):
1919
@classmethod
2020
def get_type(cls, data) -> Type['PredictorNode']:
2121
"""Return the subtype."""
22+
from .attribute_accumulation_predictor import AttributeAccumulationPredictor
2223
from .expression_predictor import ExpressionPredictor
2324
from .molecular_structure_featurizer import MolecularStructureFeaturizer
2425
from .ingredients_to_formulation_predictor import IngredientsToFormulationPredictor
@@ -30,6 +31,7 @@ def get_type(cls, data) -> Type['PredictorNode']:
3031
from .chemical_formula_featurizer import ChemicalFormulaFeaturizer
3132
type_dict = {
3233
"AnalyticExpression": ExpressionPredictor,
34+
"AttributeAccumulation": AttributeAccumulationPredictor,
3335
"MoleculeFeaturizer": MolecularStructureFeaturizer,
3436
"IngredientsToSimpleMixture": IngredientsToFormulationPredictor,
3537
"MeanProperty": MeanPropertyPredictor,

src/citrine/resources/table_config.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from citrine.gemtables.variables import (
2424
Variable, IngredientIdentifierByProcessTemplateAndName, IngredientQuantityByProcessAndName,
2525
IngredientQuantityDimension, IngredientIdentifierInOutput, IngredientQuantityInOutput,
26-
IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput
26+
IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput,
27+
AttributeByTemplateAndObjectTemplate, LocalAttributeAndObject
2728
)
2829

2930
from typing import TYPE_CHECKING
@@ -429,6 +430,28 @@ def add_all_ingredients_in_output(self, *,
429430
new_config.version_uid = copy(self.version_uid)
430431
return new_config
431432

433+
def _convert_to_multistep(self) -> "TableConfig":
434+
"""Convert the TableConfig to look like something generated by MULTISTEP_MATERIALS."""
435+
dup: TableConfig = TableConfig.build(self.dump())
436+
437+
def _convert_local(old: Variable) -> Variable:
438+
if isinstance(old, AttributeByTemplateAndObjectTemplate):
439+
return LocalAttributeAndObject(
440+
name=old.name,
441+
headers=old.headers,
442+
template=old.attribute_template,
443+
object_template=old.object_template,
444+
attribute_constraints=old.attribute_constraints,
445+
type_selector=old.type_selector,
446+
)
447+
else:
448+
return old
449+
450+
dup.variables = [_convert_local(x) for x in dup.variables]
451+
dup.generation_algorithm = TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS
452+
453+
return dup
454+
432455

433456
class TableConfigCollection(Collection[TableConfig]):
434457
"""Represents the collection of all Table Configs associated with a project."""

tests/informatics/test_predictors.py

+61
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,28 @@ def ingredient_fractions_predictor() -> IngredientFractionsPredictor:
200200
)
201201

202202

203+
@pytest.fixture
204+
def input_accumulation_predictor(auto_ml) -> AttributeAccumulationPredictor:
205+
"""Build an accumulation node for model inputs."""
206+
return AttributeAccumulationPredictor(
207+
name='Input accumulation predictor',
208+
description='Bubbles attributes up through the graph',
209+
attributes=auto_ml.inputs,
210+
sequential=False
211+
)
212+
213+
214+
@pytest.fixture
215+
def output_accumulation_predictor(auto_ml) -> AttributeAccumulationPredictor:
216+
"""Build an accumulation node for model outputs."""
217+
return AttributeAccumulationPredictor(
218+
name='Output accumulation predictor',
219+
description='Bubbles attributes up through the graph',
220+
attributes=auto_ml.outputs,
221+
sequential=True
222+
)
223+
224+
203225
def test_simple_report(graph_predictor):
204226
"""Ensures we get a report from a simple predictor post_build call"""
205227
with pytest.raises(ValueError):
@@ -453,6 +475,17 @@ def test_ingredient_fractions_property_initialization(ingredient_fractions_predi
453475
assert str(ingredient_fractions_predictor) == expected_str
454476

455477

478+
def test_attribute_accumulation_predictor_initialization(input_accumulation_predictor, output_accumulation_predictor):
479+
"""Make sure the correct fields go to the correct places for an attribute accumulation predictor."""
480+
assert len(input_accumulation_predictor.attributes) == 2
481+
expected_input = f"<AttributeAccumulationPredictor '{input_accumulation_predictor.name}'>"
482+
assert str(input_accumulation_predictor) == expected_input
483+
484+
assert len(output_accumulation_predictor.attributes) == 1
485+
expected_output = f"<AttributeAccumulationPredictor '{output_accumulation_predictor.name}'>"
486+
assert str(output_accumulation_predictor) == expected_output
487+
488+
456489
def test_status(graph_predictor, valid_graph_predictor_data):
457490
"""Ensure we can check the status of predictor validation."""
458491
# A locally built predictor should be "False" for all status checks
@@ -485,3 +518,31 @@ def test_single_predict(graph_predictor):
485518
prediction_out = graph_predictor.predict(request)
486519
assert prediction_out.dump() == prediction_in.dump()
487520
assert session.post_resource.call_count == 1
521+
522+
def test__convert_to_multistep(molecule_featurizer, auto_ml, mean_property_predictor, ingredient_fractions_predictor,
523+
label_fractions_predictor, expression_predictor, output_accumulation_predictor,
524+
input_accumulation_predictor):
525+
"""Verify graph predictor multistep material update."""
526+
graph_predictor = GraphPredictor(
527+
name='Graph predictor',
528+
description='description',
529+
predictors=[molecule_featurizer, auto_ml, mean_property_predictor, ingredient_fractions_predictor, label_fractions_predictor],
530+
training_data=[data_source, formulation_data_source]
531+
)
532+
updated = graph_predictor._convert_to_multistep()
533+
assert len(updated.predictors) == len(graph_predictor.predictors) + 2
534+
generated_accumulation = [p for p in updated.predictors if isinstance(p, AttributeAccumulationPredictor)]
535+
assert generated_accumulation[0].attributes == output_accumulation_predictor.attributes
536+
assert generated_accumulation[1].attributes == input_accumulation_predictor.attributes
537+
538+
with pytest.raises(ValueError):
539+
updated._convert_to_multistep()
540+
541+
542+
with pytest.raises(NotImplementedError):
543+
GraphPredictor(
544+
name='Graph predictor',
545+
description='description',
546+
predictors=[expression_predictor],
547+
training_data=[data_source, formulation_data_source]
548+
)._convert_to_multistep()

tests/resources/test_table_config.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
IngredientQuantityDimension, IngredientQuantityByProcessAndName, \
1010
IngredientIdentifierByProcessTemplateAndName, TerminalMaterialIdentifier, \
1111
IngredientQuantityInOutput, IngredientIdentifierInOutput, \
12-
IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput
12+
IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput, AttributeByTemplateAndObjectTemplate, \
13+
LocalAttribute, LocalAttributeAndObject
1314
from citrine.resources.table_config import TableConfig, TableConfigCollection, TableBuildAlgorithm, \
1415
TableFromGemdQueryAlgorithm
1516
from citrine.resources.data_concepts import CITRINE_SCOPE
@@ -900,3 +901,20 @@ def test_update_unregistered_fail(collection, session):
900901
def test_delete(collection):
901902
with pytest.raises(NotImplementedError):
902903
collection.delete(empty_defn().config_uid)
904+
905+
906+
def test__convert_to_multistep():
907+
variables = [
908+
AttributeByTemplate("One", headers=["one"], template=uuid4()),
909+
AttributeByTemplateAndObjectTemplate("Two", headers=["two"], attribute_template=uuid4(), object_template=uuid4()),
910+
LocalAttribute("Three", headers=["three"], template=uuid4()),
911+
LocalAttributeAndObject("Four", headers=["four"], template=uuid4(), object_template=uuid4()),
912+
]
913+
columns = [MeanColumn(data_source=v.name, target_units="") for v in variables]
914+
config: TableConfig = TableConfig.build(TableConfigDataFactory(
915+
variables=[v.dump() for v in variables],
916+
columns=[c.dump() for c in columns],
917+
))
918+
updated = config._convert_to_multistep()
919+
assert len(config.variables) == len(config.variables)
920+
assert not any(isinstance(x, AttributeByTemplateAndObjectTemplate) for x in updated.variables)

0 commit comments

Comments
 (0)