From 04b43feba2bfe63b792b4156c5b189c731f2366b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Thu, 21 Nov 2024 11:50:37 +0300 Subject: [PATCH 1/3] set methods and corresponding tests updated. A minor change in dict_to_shape and shape_to_dict --- mithril/framework/logical/base.py | 116 ++++++++++++++++---------- mithril/utils/dict_conversions.py | 6 +- tests/scripts/test_constant_inputs.py | 24 +++--- tests/scripts/test_set_shapes.py | 51 ++++++++++- tests/scripts/test_set_types.py | 47 +++++++++++ tests/scripts/test_set_values.py | 82 +++++++++++++++++- 6 files changed, 263 insertions(+), 63 deletions(-) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index cb39f014..d3b6f3b5 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -17,6 +17,7 @@ import abc from collections.abc import Callable, Mapping from dataclasses import dataclass +from itertools import chain from types import UnionType from typing import Any @@ -212,38 +213,80 @@ def create_connection(self, metadata: IOHyperEdge, key: str) -> ConnectionData: return con.data def _set_shapes( - self, shapes: ShapesType, *, trace: bool = False, updates: Updates | None = None + self, + shapes: ShapesType, + trace: bool = False, + updates: Updates | None = None, + **kwargs: ShapeTemplateType, ) -> None: - if trace: - self.assigned_shapes.append(shapes) + # Initialize assigned shapes dictionary to store assigned shapes. + assigned_shapes = {} if updates is None: updates = Updates() model = self._get_outermost_parent() - metadatas: OrderedSet[IOHyperEdge] = OrderedSet() + seen_metadata: OrderedSet[IOHyperEdge] = OrderedSet() used_keys: dict[str | int, ShapeType] = {} shape_nodes = {} - for key, shape in shapes.items(): + # for key, shape in shapes.items(): + for key, shape in chain(shapes.items(), kwargs.items()): metadata = self.conns.extract_metadata(key) if metadata is None: raise KeyError("Requires valid IO connection to set shapes!") - if metadata in metadatas: - raise KeyError("shape of same connection has already given") - metadatas.add(metadata) + if metadata in seen_metadata: + raise KeyError("Shape of same connection has already given") + seen_metadata.add(metadata) outer_conn = next(iter(model.conns.metadata_dict[metadata])) + # Convert Connection type keys into corresponding string + # representation. + if isinstance(key, Connection): + key = model.conns.get_con_by_metadata(metadata).key + assigned_shapes[key] = shape given_repr = create_shape_repr(shape, model.constraint_solver, used_keys) shape_nodes[outer_conn.key] = given_repr.node - for key, node in shape_nodes.items(): - shape_node = model.conns.get_shape_node(key) + + if trace: + self.assigned_shapes.append(assigned_shapes) + + for _key, node in shape_nodes.items(): + shape_node = model.conns.get_shape_node(_key) assert shape_node is not None updates |= shape_node.merge(node) model.constraint_solver(updates) - def set_shapes(self, shapes: ShapesType) -> None: - self._set_shapes(shapes, trace=True) + def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates: + """ + Set value for the given connection. + + Args: + key (str | Connection): Connection key or Connection object to set value. + value (MainValueType): Value to set for the given connection. - def set_values(self, values: dict[str | Connection, MainValueType | str]) -> None: + Raises: + KeyError: If the provided key is not a valid IO key. + """ + + if key.key not in self.conns.input_keys: + raise KeyError("Internal or output keys' values cannot be set.") + # Data is scalar, set the value directly. + return key.metadata.data.set_value(value) + + def set_shapes( + self, config: ShapesType | None = None, **kwargs: ShapeTemplateType + ) -> None: + if config is None: + config = {} + self._set_shapes(config, trace=True, updates=None, **kwargs) + + def set_values( + self, + config: Mapping[str | Connection, MainValueType | str] + | Mapping[Connection, MainValueType | str] + | Mapping[str, MainValueType | str] + | None = None, + **kwargs: MainValueType | str, + ) -> None: """ Set multiple values in the model. This method updates values in the outermost model by traversing up the @@ -252,28 +295,27 @@ def set_values(self, values: dict[str | Connection, MainValueType | str]) -> Non with the updated values. Args: - values (dict[str | Connection, MainValueType]): A dictionary where + config (dict[str | Connection, MainValueType]): A dictionary where keys are either strings or Connection objects, and values are of type MainValueType. + **kwargs (MainValueType): Key-value pairs where keys are string names + of connections present in this model. Raises: KeyError: If a valid key or Connection is not provided in the values dictionary. """ - - # Make all value updates in the outermost model. + if config is None: + config = {} + # Make all value updates in the outermost model.s model = self._get_outermost_parent() - updates = Updates() - - for key, value in values.items(): + for key, value in chain(config.items(), kwargs.items()): # Perform metadata extraction process on self. metadata = self.conns.extract_metadata(key) - # Perform validity check and updates on model. if (conn_data := model.conns.get_con_by_metadata(metadata)) is None: raise KeyError("Requires valid key or Connection to set values!") - updates |= model._set_value(conn_data, value) # Solve constraints with the updated values. @@ -281,9 +323,11 @@ def set_values(self, values: dict[str | Connection, MainValueType | str]) -> Non def set_types( self, - types: Mapping[str | Connection, type | UnionType] + config: Mapping[str | Connection, type | UnionType] | Mapping[Connection, type | UnionType] - | Mapping[str, type | UnionType], + | Mapping[str, type | UnionType] + | None = None, + **kwargs: type | UnionType, ): """ Set types of any connection in the Model @@ -300,33 +344,19 @@ def set_types( of type of type or UnionType objects. """ - # get the outermost parent as all the updates will happen here + if config is None: + config = {} + + # Get the outermost parent as all the updates will happen here. model = self._get_outermost_parent() updates = Updates() - for key, key_type in types.items(): + for key, key_type in chain(config.items(), kwargs.items()): metadata = self.conns.extract_metadata(key) data = metadata.data updates |= data.set_type(key_type) - # run the constraints for updating affected connections + # Run the constraints for updating affected connections. model.constraint_solver(updates) - def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates: - """ - Set value for the given connection. - - Args: - key (str | Connection): Connection key or Connection object to set value. - value (MainValueType): Value to set for the given connection. - - Raises: - KeyError: If the provided key is not a valid IO key. - """ - - if key.key not in self.conns.input_keys: - raise KeyError("Internal or output keys' values cannot be set.") - # Data is scalar, set the value directly. - return key.metadata.data.set_value(value) - def get_shapes( self, uni_keys=None, var_keys=None, symbolic=True, verbose=False ) -> _ShapesType: diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index e5c43287..664c8e6e 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -549,7 +549,7 @@ def shape_to_dict(shapes): if isinstance(item, tuple): # variadic shape_list.append(f"{item[0]},...") else: - shape_list.append(str(item)) + shape_list.append(item) shape_dict[key] = shape_list return shape_dict @@ -559,10 +559,10 @@ def dict_to_shape(shape_dict): for key, shape_list in shape_dict.items(): shapes[key] = [] for shape in shape_list: - if "..." in shape: + if isinstance(shape, str) and "..." in shape: shapes[key].append((shape.split(",")[0], ...)) else: - shapes[key].append(int(shape)) + shapes[key].append(shape) return shapes diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index 221bcc58..70a8ff13 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -1764,13 +1764,12 @@ def test_unused_cached_values_1_set_values(): model = Model() linear_model = Linear(dimension=2) model += linear_model() - model.set_values( - { - linear_model.w: [[1.0, 2.0]], - linear_model.b: [3.0, 1.0], - linear_model.input: [[3.0], [2.0]], - } - ) + config: dict[Connection, list] = { + linear_model.w: [[1.0, 2.0]], + linear_model.b: [3.0, 1.0], + linear_model.input: [[3.0], [2.0]], + } + model.set_values(config) comp_model = mithril.compile(model=model, backend=(backend := NumpyBackend())) dtype = backend.get_backend_array_type() cache = comp_model.data_store.data_values @@ -1835,12 +1834,11 @@ def test_unused_cached_values_2_set_values(): model = Model() linear_model = Linear(dimension=2) model += linear_model() - model.set_values( - { - linear_model.w: [[1.0, 2.0]], - linear_model.b: [3.0, 1.0], - } - ) + config: dict[Connection, list] = { + linear_model.w: [[1.0, 2.0]], + linear_model.b: [3.0, 1.0], + } + model.set_values(config) comp_model = mithril.compile( model=model, backend=(backend := NumpyBackend()), safe_names=False ) diff --git a/tests/scripts/test_set_shapes.py b/tests/scripts/test_set_shapes.py index 6b405df6..698becfd 100644 --- a/tests/scripts/test_set_shapes.py +++ b/tests/scripts/test_set_shapes.py @@ -38,6 +38,53 @@ def test_set_shapes_1(): check_shapes_semantically(ref_shapes, model.shapes) +def test_set_shapes_1_kwargs_arg(): + model = Model() + + model += Sigmoid()("input1", IOKey("output1")) + model += Sigmoid()("input2", IOKey("output2")) + + model.set_shapes(input1=["a", "b"], input2=["b", "a"]) + + ref_shapes = { + "input1": ["a", "b"], + "output1": ["a", "b"], + "input2": ["b", "a"], + "output2": ["b", "a"], + } + + check_shapes_semantically(ref_shapes, model.shapes) + + +def test_set_shapes_1_hybrid_arg(): + model = Model() + + model += Sigmoid()("input1", IOKey("output1")) + model += Sigmoid()("input2", IOKey("output2")) + + model.set_shapes({"input1": ["a", "b"]}, input2=["b", "a"]) + + ref_shapes = { + "input1": ["a", "b"], + "output1": ["a", "b"], + "input2": ["b", "a"], + "output2": ["b", "a"], + } + + check_shapes_semantically(ref_shapes, model.shapes) + + +def test_set_shapes_1_hybrid_arg_same_metadata(): + model = Model() + + model += Sigmoid()("input1", IOKey("output1")) + model += Sigmoid()("input2", IOKey("output2")) + + with pytest.raises(KeyError) as err_info: + model.set_shapes({model.input2: ["a", "b"]}, input2=["b", "a"]) # type: ignore + assert str(err_info.value) == "'Shape of same connection has already given'" + + def test_set_shapes_2(): model = Model() @@ -113,7 +160,7 @@ def test_set_shapes_6(): model3 += model2(left="left", right="right", output=IOKey("output")) model4 += model3(left="left", right="right", output=IOKey("output")) - model3.set_shapes({"left": [3, 4], add1.right: [3, 4], model4.output: [3, 4]}) # type: ignore + model3.set_shapes({add1.right: [3, 4], model4.output: [3, 4]}, left=[3, 4]) # type: ignore ref_shapes = {"left": [3, 4], "right": [3, 4], "output": [3, 4]} @@ -166,4 +213,4 @@ def test_set_shapes_7_error(): with pytest.raises(KeyError) as err_info: model3.set_shapes({"left": [3, 4], add1.left: [3, 4], model4.output: [3, 4]}) # type: ignore - assert str(err_info.value) == "'shape of same connection has already given'" + assert str(err_info.value) == "'Shape of same connection has already given'" diff --git a/tests/scripts/test_set_types.py b/tests/scripts/test_set_types.py index e69d74da..acb42c98 100644 --- a/tests/scripts/test_set_types.py +++ b/tests/scripts/test_set_types.py @@ -27,6 +27,15 @@ def test_set_types_1(): assert input_data._type is int +def test_set_types_1_kwargs_arg(): + model = Model() + sig_model = Sigmoid() + model += sig_model(input="input", output=IOKey("output")) + model.set_types(input=int) + input_data = sig_model.input.metadata.data + assert input_data._type is int + + def test_set_types_2(): model = Model() buffer_model = Buffer() @@ -36,6 +45,15 @@ def test_set_types_2(): assert input_data._type == int | bool +def test_set_types_2_kwargs_arg(): + model = Model() + buffer_model = Buffer() + model += buffer_model(input="input", output=IOKey(name="output")) + model.set_types(input=int | bool) + input_data = buffer_model.input.metadata.data + assert input_data._type == int | bool + + def test_set_types_3(): model = Model() buffer_model = Buffer() @@ -45,6 +63,24 @@ def test_set_types_3(): assert input_data._type == int | bool +def test_set_types_3_kwargs_arg_1(): + model = Model() + buffer_model = Buffer() + model += buffer_model(input="input", output=IOKey(name="output")) + model.set_types(input=int | bool) + input_data = buffer_model.input.metadata.data + assert input_data._type == int | bool + + +def test_set_types_3_kwargs_arg_2(): + model = Model() + buffer_model = Buffer() + model += buffer_model(input="input", output=IOKey(name="output")) + buffer_model.set_types(input=int | bool) + input_data = buffer_model.input.metadata.data + assert input_data._type == int | bool + + def test_set_types_4(): model = Model() buffer_model = Buffer() @@ -67,6 +103,17 @@ def test_set_types_5(): assert input_data_2._type is float +def test_set_types_5_key_error(): + model = Model() + buffer_model_1 = Buffer() + buffer_model_2 = Buffer() + model += buffer_model_1(input="input1", output=IOKey(name="output1")) + model += buffer_model_2(input="input2", output=IOKey(name="output2")) + with pytest.raises(KeyError) as err_info: + model.set_types({model.input1: int | bool, "input": float}) # type: ignore + assert str(err_info.value) == "'Key input is not found in connections.'" + + def test_set_types_6(): model = Model() buffer_model_1 = Buffer() diff --git a/tests/scripts/test_set_values.py b/tests/scripts/test_set_values.py index 2ec8dd5c..5b59da2e 100644 --- a/tests/scripts/test_set_values.py +++ b/tests/scripts/test_set_values.py @@ -16,7 +16,7 @@ import mithril from mithril import JaxBackend -from mithril.models import TBD, Add, IOKey, Linear, Mean, Model, Relu, Shape +from mithril.models import TBD, Add, Connection, IOKey, Linear, Mean, Model, Relu, Shape from ..utils import check_evaluations, compare_models, init_params from .test_utils import assert_results_equal @@ -92,6 +92,28 @@ def test_set_values_scalar_1(): assert_results_equal(outputs, ref_outputs) +def test_set_values_scalar_1_kwargs_arg(): + backend = JaxBackend() + model = Model() + mean_model = Mean(axis=TBD) + model += mean_model(input="input", output=IOKey("output", shape=[2, 2])) + mean_model.set_values(axis=1) + + pm = mithril.compile(model=model, backend=JaxBackend()) + params = {"input": backend.ones(2, 2)} + data: dict = {} + gradients = {"output": backend.ones(2)} + + ref_outputs = {"output": backend.ones(2)} + + ref_grads = {"input": backend.ones(2, 2) / 2} + + outputs, grads = pm.evaluate_all(params, data, gradients) + + assert_results_equal(grads, ref_grads) + assert_results_equal(outputs, ref_outputs) + + def test_set_values_scalar_2(): backend = JaxBackend() model = Model() @@ -167,13 +189,46 @@ def test_set_values_scalar_6(): mean_model = Mean(axis=TBD) model += mean_model(input="input", axis="axis", output="output") with pytest.raises(ValueError) as err_info: - model.set_values({"axis": (0, 1), mean_model.axis: (0, 2)}) + config: dict[str | Connection, tuple[int, int]] = { + "axis": (0, 1), + mean_model.axis: (0, 2), + } + model.set_values(config) assert ( str(err_info.value) == "Value is set before as (0, 1). A scalar value can not be reset." ) +def test_set_values_scalar_6_kwargs_arg(): + model = Model() + mean_model = Mean(axis=TBD) + model += mean_model(input="input", axis="axis", output="output") + with pytest.raises(ValueError) as err_info: + config = {mean_model.axis: (0, 2)} + model.set_values(config, axis=(0, 1)) + assert ( + str(err_info.value) + == "Value is set before as (0, 2). A scalar value can not be reset." + ) + + +def test_set_values_scalar_6_same_conn_in_config(): + model = Model() + mean_model = Mean(axis=TBD) + model += mean_model(input="input", axis="axis", output="output") + with pytest.raises(ValueError) as err_info: + config: dict[Connection | str, tuple[int, int]] = { + mean_model.axis: (0, 2), + "axis": (0, 1), + } + model.set_values(config) + assert ( + str(err_info.value) + == "Value is set before as (0, 2). A scalar value can not be reset." + ) + + def test_set_values_tensor_1(): backend = JaxBackend() @@ -197,6 +252,29 @@ def test_set_values_tensor_1(): assert_results_equal(ref_outputs, outputs) +def test_set_values_tensor_1_kwargs_arg(): + backend = JaxBackend() + + model1 = Model() + add_model_1 = Add() + + model1 += add_model_1(left="input1", right="input2", output=IOKey("output")) + + model2 = Model() + add_model_2 = Add() + model2 += model1(input1="input1", input2="sub_input", output=IOKey("output")) + model2 += add_model_2(left="input1", right="input2", output="sub_input") + # add_model_2.set_values({"right": [2.0]}) + model2.set_values({"input1": [3.0]}, input2=[2.0]) + pm = mithril.compile(model=model2, backend=JaxBackend()) + + ref_outputs = {"output": backend.array([8.0])} + + outputs = pm.evaluate() + + assert_results_equal(ref_outputs, outputs) + + def test_set_values_tensor_2(): backend = JaxBackend() From e418d20ce447ea97a17158a832b2e4d7cb65c658 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Thu, 21 Nov 2024 14:10:05 +0300 Subject: [PATCH 2/3] Merge bug fixed. --- mithril/framework/logical/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index d4331998..b129d73a 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -313,7 +313,7 @@ def set_values( updates = Updates() # TODO: Currently Setting values in fozen models are prevented only for Tensors. # Scalar and Tensors should not be operated differently. This should be fixed. - for key, value in chain(config.items(), kwargs.items()): + for key in chain(config, kwargs): metadata = self.conns.extract_metadata(key) if isinstance(metadata.data, Tensor) and model.is_frozen: conn_data = model.conns.get_con_by_metadata(metadata) @@ -323,7 +323,6 @@ def set_values( ) for key, value in chain(config.items(), kwargs.items()): - # Perform metadata extraction process on self. metadata = self.conns.extract_metadata(key) # Perform validity check and updates on model. From 6ee875fcf2c7af7516ccc77f397191f4b1856a00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Fri, 22 Nov 2024 12:38:00 +0300 Subject: [PATCH 3/3] Bug in _set_shapes fixed. --- mithril/framework/logical/base.py | 30 ++++++++-------- tests/scripts/test_model_to_dict_rtt.py | 47 +++++++++++++++++++++++++ tests/scripts/test_set_shapes.py | 34 +++++++----------- 3 files changed, 74 insertions(+), 37 deletions(-) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index b129d73a..302eb138 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -227,33 +227,31 @@ def _set_shapes( updates = Updates() model = self._get_outermost_parent() - seen_metadata: OrderedSet[IOHyperEdge] = OrderedSet() used_keys: dict[str | int, ShapeType] = {} shape_nodes = {} - # for key, shape in shapes.items(): + # TODO: Can this be refactored to use a single loop? for key, shape in chain(shapes.items(), kwargs.items()): metadata = self.conns.extract_metadata(key) if metadata is None: raise KeyError("Requires valid IO connection to set shapes!") - if metadata in seen_metadata: - raise KeyError("Shape of same connection has already given") - seen_metadata.add(metadata) - outer_conn = next(iter(model.conns.metadata_dict[metadata])) - # Convert Connection type keys into corresponding string - # representation. - if isinstance(key, Connection): - key = model.conns.get_con_by_metadata(metadata).key - assigned_shapes[key] = shape given_repr = create_shape_repr(shape, model.constraint_solver, used_keys) - shape_nodes[outer_conn.key] = given_repr.node + # Get inner string representation of the metadata and save + # use this name in order to merge . + conn = self.conns.get_con_by_metadata(metadata) + assert conn is not None + inner_key = conn.key + shape_nodes[key] = (given_repr.node, inner_key) + assigned_shapes[inner_key] = shape + # Apply updates to the shape nodes. + for key in chain(shapes, kwargs): + node, _inner_key = shape_nodes[key] + shape_node = self.conns.get_shape_node(_inner_key) + assert shape_node is not None + updates |= shape_node.merge(node) if trace: self.assigned_shapes.append(assigned_shapes) - for _key, node in shape_nodes.items(): - shape_node = model.conns.get_shape_node(_key) - assert shape_node is not None - updates |= shape_node.merge(node) model.constraint_solver(updates) def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates: diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index d768cee4..f61ed6b4 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -14,6 +14,8 @@ import re +import pytest + import mithril from mithril import JaxBackend, TorchBackend from mithril.framework.common import TBD, IOKey @@ -58,6 +60,51 @@ def test_linear_expose(): ) +def test_linear_expose_set_shapes(): + model = Model() + lin_1 = Linear() + lin_2 = Linear() + model += lin_1(input="input", w="w") + model += lin_2(input=lin_1.output, w="w1", output=IOKey(name="output2")) + model.set_shapes({lin_1.b: [42]}) + model.set_shapes({lin_2.b: [21]}) + model_dict_created = dict_conversions.model_to_dict(model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert model.shapes == model_recreated.shapes + assert_models_equal(model, model_recreated) + + backend = JaxBackend(precision=64) + assert_evaluations_equal( + model, model_recreated, backend, static_keys={"input": backend.ones([4, 256])} + ) + + +@pytest.mark.skip(reason="Dict conversion does not support extend from inputs yet") +def test_linear_expose_set_shapes_extend_from_inputs(): + model = Model() + lin_1 = Linear() + lin_2 = Linear() + model += lin_2(w="w1", output=IOKey(name="output2")) + model += lin_1(input="input", w="w", output=lin_2.input) + model.set_shapes({lin_1.b: [42]}) + model.set_shapes({lin_2.b: [21]}) + model_dict_created = dict_conversions.model_to_dict(model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert model.shapes == model_recreated.shapes + assert_models_equal(model, model_recreated) + + backend = JaxBackend(precision=64) + assert_evaluations_equal( + model, model_recreated, backend, static_keys={"input": backend.ones([4, 256])} + ) + + def test_linear_set_diff(): model = Model() linear = Linear(dimension=42) diff --git a/tests/scripts/test_set_shapes.py b/tests/scripts/test_set_shapes.py index 698becfd..15dc8942 100644 --- a/tests/scripts/test_set_shapes.py +++ b/tests/scripts/test_set_shapes.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from mithril.framework.common import IOKey from mithril.models import Add, Model, Sigmoid @@ -74,15 +73,24 @@ def test_set_shapes_1_hybrid_arg(): check_shapes_semantically(ref_shapes, model.shapes) -def test_set_shapes_1_hybrid_arg_same_metadata(): +def test_set_shapes_1_hybrid_arg_same_metadata_1(): model = Model() model += Sigmoid()("input1", IOKey("output1")) model += Sigmoid()("input2", IOKey("output2")) - with pytest.raises(KeyError) as err_info: - model.set_shapes({model.input2: ["a", "b"]}, input2=["b", "a"]) # type: ignore - assert str(err_info.value) == "'Shape of same connection has already given'" + model.set_shapes({model.input2: ["a", "b"]}, input2=[2, 3]) # type: ignore + assert model.shapes[model.input2.key] == [2, 3] # type: ignore + + +def test_set_shapes_1_hybrid_arg_same_metadata_2(): + model = Model() + + model += Sigmoid()("input1", IOKey("output1")) + model += Sigmoid()("input2", IOKey("output2")) + + model.set_shapes({model.input2: [2, 3]}, input2=["a", "b"]) # type: ignore + assert model.shapes[model.input2.key] == [2, 3] # type: ignore def test_set_shapes_2(): @@ -198,19 +206,3 @@ def test_set_shapes_8(): "output": ["(V1, ...)"], } check_shapes_semantically(ref_shapes, model.shapes) - - -def test_set_shapes_7_error(): - model1 = Model() - model2 = Model() - model3 = Model() - model4 = Model() - - model1 += (add1 := Add())(left="left", right="right", output=IOKey("output")) - model2 += model1(left="left", right="right", output=IOKey("output")) - model3 += model2(left="left", right="right", output=IOKey("output")) - model4 += model3(left="left", right="right", output=IOKey("output")) - - with pytest.raises(KeyError) as err_info: - model3.set_shapes({"left": [3, 4], add1.left: [3, 4], model4.output: [3, 4]}) # type: ignore - assert str(err_info.value) == "'Shape of same connection has already given'"