diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index d13fb87e..ecc8d714 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -849,20 +849,21 @@ def extend( # Update Canonicals if isinstance(c_input := model.canonical_input, Connection): c_input_obj = self.conns.get_con_by_metadata(c_input.data.metadata) - if c_input_obj not in self.dependency_map._local_output_dependency_map: - # Update canonical input with model canonical input - if c_input_obj not in self.conns.input_connections: - self._canonical_input = NOT_AVAILABLE - else: - assert c_input_obj is not None - self._canonical_input = c_input_obj + if c_input_obj is not None and c_input_obj.metadata.data.value is TBD: + if c_input_obj not in self.dependency_map._local_output_dependency_map: + # Update canonical input with model canonical input + if c_input_obj not in self.conns.input_connections: + self._canonical_input = NOT_AVAILABLE + else: + assert c_input_obj is not None + self._canonical_input = c_input_obj - elif ( - self._canonical_input - in self.dependency_map._local_output_dependency_map - ): - # Model canonical output used as input than make it None - self._canonical_input = NOT_AVAILABLE + elif ( + self._canonical_input + in self.dependency_map._local_output_dependency_map + ): + # Model canonical output used as input than make it None + self._canonical_input = NOT_AVAILABLE if isinstance(c_output := model.canonical_output, Connection): c_output_obj = self.conns.get_con_by_metadata(c_output.data.metadata) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index b4a9048d..e1de9596 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -952,6 +952,36 @@ def test_canonical_input_9(): assert model.canonical_output == model.output # type: ignore +def test_canonical_input_10(): + # Valued cannection cannot be canonical input + model = Model() + model += Add()(left=3, right="input2", output="output") + + assert model.canonical_input is NOT_AVAILABLE + + +def test_canonical_input_11(): + # Valued cannection cannot be canonical input + model = Model() + model += (buff := Buffer()) + model += Add()(left=3, right="input2", output="output") + + canonical_input = model.canonical_input + assert not isinstance(canonical_input, NotAvailable) + assert canonical_input.metadata == buff.input.metadata + + +def test_canonical_input_12(): + # Valued cannection cannot be canonical input + model = Model() + model += (buff := Buffer()) + model += Buffer()(input=3 / buff.output) + + canonical_input = model.canonical_input + assert not isinstance(canonical_input, NotAvailable) + assert canonical_input.metadata == buff.input.metadata + + def test_canonical_dual_iadd_op(): model1 = Model() model1 += (c1 := Convolution2D(3, 4))