@@ -200,6 +200,28 @@ def ingredient_fractions_predictor() -> IngredientFractionsPredictor:
200
200
)
201
201
202
202
203
+ @pytest .fixture
204
+ def input_accumulation_predictor (auto_ml ) -> AttributeAccumulationPredictor :
205
+ """Build an accumulation node for model inputs."""
206
+ return AttributeAccumulationPredictor (
207
+ name = 'Input accumulation predictor' ,
208
+ description = 'Bubbles attributes up through the graph' ,
209
+ attributes = auto_ml .inputs ,
210
+ sequential = False
211
+ )
212
+
213
+
214
+ @pytest .fixture
215
+ def output_accumulation_predictor (auto_ml ) -> AttributeAccumulationPredictor :
216
+ """Build an accumulation node for model outputs."""
217
+ return AttributeAccumulationPredictor (
218
+ name = 'Output accumulation predictor' ,
219
+ description = 'Bubbles attributes up through the graph' ,
220
+ attributes = auto_ml .outputs ,
221
+ sequential = True
222
+ )
223
+
224
+
203
225
def test_simple_report (graph_predictor ):
204
226
"""Ensures we get a report from a simple predictor post_build call"""
205
227
with pytest .raises (ValueError ):
@@ -453,6 +475,17 @@ def test_ingredient_fractions_property_initialization(ingredient_fractions_predi
453
475
assert str (ingredient_fractions_predictor ) == expected_str
454
476
455
477
478
+ def test_attribute_accumulation_predictor_initialization (input_accumulation_predictor , output_accumulation_predictor ):
479
+ """Make sure the correct fields go to the correct places for an attribute accumulation predictor."""
480
+ assert len (input_accumulation_predictor .attributes ) == 2
481
+ expected_input = f"<AttributeAccumulationPredictor '{ input_accumulation_predictor .name } '>"
482
+ assert str (input_accumulation_predictor ) == expected_input
483
+
484
+ assert len (output_accumulation_predictor .attributes ) == 1
485
+ expected_output = f"<AttributeAccumulationPredictor '{ output_accumulation_predictor .name } '>"
486
+ assert str (output_accumulation_predictor ) == expected_output
487
+
488
+
456
489
def test_status (graph_predictor , valid_graph_predictor_data ):
457
490
"""Ensure we can check the status of predictor validation."""
458
491
# A locally built predictor should be "False" for all status checks
@@ -485,3 +518,31 @@ def test_single_predict(graph_predictor):
485
518
prediction_out = graph_predictor .predict (request )
486
519
assert prediction_out .dump () == prediction_in .dump ()
487
520
assert session .post_resource .call_count == 1
521
+
522
+ def test__convert_to_multistep (molecule_featurizer , auto_ml , mean_property_predictor , ingredient_fractions_predictor ,
523
+ label_fractions_predictor , expression_predictor , output_accumulation_predictor ,
524
+ input_accumulation_predictor ):
525
+ """Verify graph predictor multistep material update."""
526
+ graph_predictor = GraphPredictor (
527
+ name = 'Graph predictor' ,
528
+ description = 'description' ,
529
+ predictors = [molecule_featurizer , auto_ml , mean_property_predictor , ingredient_fractions_predictor , label_fractions_predictor ],
530
+ training_data = [data_source , formulation_data_source ]
531
+ )
532
+ updated = graph_predictor ._convert_to_multistep ()
533
+ assert len (updated .predictors ) == len (graph_predictor .predictors ) + 2
534
+ generated_accumulation = [p for p in updated .predictors if isinstance (p , AttributeAccumulationPredictor )]
535
+ assert generated_accumulation [0 ].attributes == output_accumulation_predictor .attributes
536
+ assert generated_accumulation [1 ].attributes == input_accumulation_predictor .attributes
537
+
538
+ with pytest .raises (ValueError ):
539
+ updated ._convert_to_multistep ()
540
+
541
+
542
+ with pytest .raises (NotImplementedError ):
543
+ GraphPredictor (
544
+ name = 'Graph predictor' ,
545
+ description = 'description' ,
546
+ predictors = [expression_predictor ],
547
+ training_data = [data_source , formulation_data_source ]
548
+ )._convert_to_multistep ()
0 commit comments