From 26dd03ad6cd5ddab7b1f6872fcc644cfdf1bc24f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Wed, 5 Mar 2025 11:56:51 +0300 Subject: [PATCH 1/7] First draft for tracing set attributes bu user for recreation of models from json or dict. --- examples/flux/util.py | 12 +- mithril/cores/python/numpy/ops.py | 2 +- mithril/framework/common.py | 44 +- mithril/framework/logical/base.py | 160 ++- mithril/framework/logical/model.py | 7 +- mithril/framework/logical/operator.py | 2 + mithril/framework/logical/primitive.py | 6 +- mithril/models/models.py | 2 + mithril/utils/dict_conversions.py | 338 +++++-- .../json_files/integration_directed_test.json | 112 ++- tests/json_files/models_directed_test.json | 908 ++++++++++-------- tests/json_files/no_grad_inference_test.json | 15 + .../randomized_model_tests_all_backends.json | 17 +- tests/scripts/helper.py | 14 +- .../summary_txts/test_physical_summary_17 | 22 +- ...test_physical_summary_3_logical_with_depth | 10 +- tests/scripts/test_all_models.py | 53 +- tests/scripts/test_differentiablity.py | 34 +- tests/scripts/test_model_to_dict_rtt.py | 87 +- 19 files changed, 1181 insertions(+), 664 deletions(-) diff --git a/examples/flux/util.py b/examples/flux/util.py index c6a583f9..0c7ada72 100644 --- a/examples/flux/util.py +++ b/examples/flux/util.py @@ -159,11 +159,11 @@ def load_flow_model(name: str, backend: ml.Backend, hf_download: bool = True): ckpt_path = configs[name].ckpt_path if ( ckpt_path is None - and configs[name].repo_id is not None - and configs[name].repo_flow is not None + and (r_id := configs[name].repo_id) is not None + and (r_flow := configs[name].repo_flow) is not None and hf_download ): - ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + ckpt_path = hf_hub_download(r_id, r_flow) flux_lm = flux(configs[name].params) flux_pm = ml.compile( @@ -187,11 +187,11 @@ def load_decoder( ckpt_path = configs[name].ae_path if ( ckpt_path is None - and 
configs[name].repo_id is not None - and configs[name].repo_ae is not None + and (r_id := configs[name].repo_id) is not None + and (r_ae := configs[name].repo_ae) is not None and hf_download ): - ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae) + ckpt_path = hf_hub_download(r_id, r_ae) # Loading the autoencoder print("Init AE") diff --git a/mithril/cores/python/numpy/ops.py b/mithril/cores/python/numpy/ops.py index aef303a3..225c19d4 100644 --- a/mithril/cores/python/numpy/ops.py +++ b/mithril/cores/python/numpy/ops.py @@ -629,7 +629,7 @@ def scaled_dot_product_attention( ) L, S = query.shape[-2], key.shape[-2] - scale_factor = 1 / np.sqrt(query.shape[-1]) if scale is None else scale + scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale write_into_cache(cache, "scale_factor", scale_factor) attn_bias = np.zeros((L, S), dtype=query.dtype) if is_causal: diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 6988f228..e6baca95 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -1043,7 +1043,7 @@ def __init__( value: TensorValueType | ToBeDetermined = TBD, type: _TensorTypes = int | float | bool, shape: ShapeNode | None = None, - differentiable: bool = False, + differentiable: bool | None = None, ): if shape is None: # If shape is not provided, create a new shape with a Variadic root. @@ -1119,8 +1119,16 @@ def match(self, other: Tensor[int | float | bool]) -> Updates: updates |= non_valued.set_value(valued.value) self.differentiable = False other.differentiable = False - else: - self.differentiable |= other.differentiable + elif self.differentiable is None: + self.differentiable = other.differentiable + # Differentiable tensors can only be float type. + if self.differentiable: + updates |= self.set_type(float) + elif ( + other.differentiable is not None + and self.differentiable != other.differentiable + ): + raise ValueError("Differentiability mismatch!") # Match shapes. 
updates |= self.match_shapes(other.shape) updates.shape_updates.discard(other) @@ -1180,8 +1188,16 @@ def _temp_shape(self) -> ShapeRepr | None: return None @property - def differentiable(self) -> bool: - return isinstance(self._value, Tensor) and self._value.differentiable + def differentiable(self) -> bool | None: + if isinstance(self._value, Tensor): + return self._value.differentiable + elif self.is_scalar: + # Scalars are always non-differentiable. + return False + # Differentiability of polymorphic edges are defined + # as None. Depending on its instant type updates, it can + # become True or False (e.g Tensor or int type). + return None @property def tensors(self) -> set[Tensor[int | float | bool]]: @@ -1405,7 +1421,7 @@ def set_value( updates.value_updates.add(self) # Update new type without automatic tensor value creation. updates |= self.set_type(find_type(self._value), create_tensor=False) - if self.is_valued: + if self.is_tensor and self.is_valued: self.set_differentiability(False) return updates @@ -1431,13 +1447,19 @@ def match(self, other: IOHyperEdge) -> Updates: return updates - def set_differentiability(self, differentiable: bool) -> None: - if self.is_tensor: - assert isinstance(self._value, Tensor) - self._value.differentiable = differentiable - elif differentiable: + def set_differentiability(self, differentiable: bool) -> Updates: + if self.is_scalar and differentiable: raise ValueError("Non-tensor edges cannot be differentiable.") + updates = Updates() + if differentiable: + # Differentiable edges can only be Tensor[float] type. + updates |= self.set_type(Tensor[float]) + # Set differentiability of the _value if it is a Tensor. 
+ if isinstance(self._value, Tensor): + self._value.differentiable = differentiable + return updates + def add_constraint(self, constraint: Constraint) -> None: for type in constraint.types: self.constraints[type].add(constraint) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 6c82ae7a..951274f2 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -75,7 +75,7 @@ def __init__( | ScalarType | None = None, expose: bool | None = None, - differentiable: bool = False, + differentiable: bool | None = None, interval: list[float | int] | None = None, ) -> None: # If shape is provided, type should be Tensor. @@ -166,15 +166,7 @@ def __hash__(self) -> int: return hash(id(self)) def set_differentiability(self, differentiable: bool = True) -> Updates: - updates = Updates() - # TODO: Move this method to Model class as set_shapes, set_types etc. - if self.metadata.is_tensor: - self.metadata.set_differentiability(differentiable) - elif differentiable: - updates |= self.metadata.set_type(Tensor[float]) - self.metadata.set_differentiability(differentiable) - - return updates + return self.metadata.set_differentiability(differentiable) BaseKey = ConnectionData @@ -380,14 +372,18 @@ def __init__( ) -> None: self.dag: dict[BaseModel, dict[str, ConnectionData]] = {} self._formula_key: str | None = formula_key - # TODO: maybe set it only to Operator / Model. 
self.parent: BaseModel | None = None - self.assigned_shapes: list[dict[str, ShapeTemplateType]] = [] - self.assigned_types: dict[ - str, - type | UnionType | ScalarType | type[Tensor[int | float | bool]], - ] = {} + self.assigned_shapes: list[ + list[tuple[tuple[BaseModel, int], ShapeTemplateType]] + ] = [] + self.assigned_types: list[ + tuple[ + tuple[BaseModel, int], + type | UnionType | ScalarType | type[Tensor[int | float | bool]], + ] + ] = [] + self.assigned_differentiabilities: list[tuple[tuple[BaseModel, int], bool]] = [] self.assigned_constraints: list[AssignedConstraintType] = [] self.conns = Connections() self.frozen_attributes: list[str] = [] @@ -510,12 +506,14 @@ def _add_connection( local_key: str, given_connection: ConnectionDataType, updates: Updates, + trace: bool, ) -> tuple[ConnectionData, Updates]: is_input = local_key in model.input_keys local_connection = model.conns.get_connection(local_key) assert local_connection is not None, "Connection is not found!" edge = local_connection.metadata is_not_valued = not edge.is_valued + set_diff = None d_map = self.dependency_map.local_output_dependency_map @@ -527,6 +525,7 @@ def _add_connection( | Tensor[int | float | bool] | NullConnection ) = NOT_GIVEN + set_type: type[Tensor[int | float | bool]] | ScalarType = ToBeDetermined is_new_connection = False match given_connection: @@ -545,6 +544,10 @@ def _add_connection( set_value = given_connection given_connection = self._create_connection(edge, None) assert isinstance(given_connection, ConnectionData) + + if given_connection.metadata.differentiable != edge.differentiable: + set_diff = given_connection.metadata.differentiable + # Connection is given as a Connection object. 
if ( con_obj := self.conns.get_con_by_metadata(given_connection.metadata) @@ -554,6 +557,8 @@ def _add_connection( is_new_connection = True expose = given_connection.is_exposed outer_key = given_connection.get_key() + if set_value is NOT_GIVEN: + set_type = given_connection.metadata.edge_type if set_value is NOT_GIVEN and given_connection.metadata.value is not TBD: set_value = given_connection.metadata._value if outer_key is not None: @@ -579,15 +584,22 @@ def _add_connection( "Expose flag cannot be false when " "no value is provided for input keys!" ) + + # Set value or type if given. if not isinstance(set_value, NullConnection): updates |= con_obj.metadata.set_value(set_value) - elif ( - set_type := given_connection.metadata.edge_type - ) is not ToBeDetermined: - model.set_types({local_connection: set_type}) - - if given_connection.metadata.differentiable: - updates |= con_obj.set_differentiability(True) + elif set_type is not ToBeDetermined: + # Skip tracing if the local connection's type is already + # set to the given type. + trace &= set_type != local_connection.metadata.edge_type + model._set_types({local_connection: set_type}, trace=trace) + + # Set differentiability if given. + if set_diff is not None: + # No need to trace differentiability for valued and + # existing connections. + trace &= is_new_connection and not given_connection.metadata.is_valued + model._set_differentiability({local_connection: set_diff}, trace) else: if given_connection in model.conns.all.values(): @@ -655,6 +667,7 @@ def _add_connection( or con_obj in self.conns.output_connections ): self.conns.couts.add(con_obj) + return con_obj, updates def rename_key(self, connection: ConnectionData, key: str) -> None: @@ -787,6 +800,8 @@ def _merge_connections( def extend( self, model: BaseModel | BaseModel, + trace: bool = True, + /, **kwargs: ConnectionDataType, ) -> None: # Check possible errors before the extension. 
@@ -836,7 +851,9 @@ def extend( ): value._expose = True - con_obj, _updates = self._add_connection(model, local_key, value, updates) + con_obj, _updates = self._add_connection( + model, local_key, value, updates, trace + ) updates |= _updates submodel_dag[local_key] = con_obj if tensors := con_obj.metadata.tensors: @@ -848,12 +865,12 @@ def extend( submodel_dag[key].key: template for key, template in shape_info.items() } - # Set given shapes. - self._set_shapes(**shape_info) # TODO: Should "trace" be set to True?. - self.constraint_solver(updates) + # # Set given shapes. + # self._set_shapes(**shape_info) # TODO: Should "trace" be set to True?. + # self.constraint_solver(updates) - model.constraint_solver.clear() - model.conns.connections_dict = {} + # model.constraint_solver.clear() + # model.conns.connections_dict = {} # Insert to self dag as a FrozenDict."" # Since we update dag in merge_connections, we could not use FrozenDict. @@ -861,6 +878,13 @@ def extend( self.dependency_map.add_model_dag(model, model_dag) + # Set given shapes. + self._set_shapes(**shape_info) # TODO: Should "trace" be set to True?. + self.constraint_solver(updates) + + model.constraint_solver.clear() + model.conns.connections_dict = {} + # Update jittablity by using model's jittablity. 
self._jittable &= model.jittable @@ -1356,8 +1380,48 @@ def _create_connection( connection.metadata = metadata return connection - def set_differentiability( - self, config: dict[ConnectionData, bool] | None = None, /, **kwargs: bool + def extract_model_key_index( + self, connection: ConnectionData + ) -> tuple[BaseModel, int]: + if connection not in self.conns.all.values(): + raise KeyError(f"Connection {connection.key} is not found in the model") + + if self.is_frozen: + _model = self + name = connection._name + assert name is not None + key_index = list(self.external_keys).index(name) + else: + assert connection.model is not None + key = connection.key + if ( + _model := connection.model + ) is self and connection.key not in self.external_keys: + # Always store info using freezed models. + model_info: ( + list[tuple[BaseModel, OrderedSet[ConnectionData]]] + | tuple[BaseModel, OrderedSet[ConnectionData]] + ) = self.dependency_map.local_input_dependency_map.get( + connection, + self.dependency_map.local_output_dependency_map.get(connection), # type: ignore + ) + _model = ( + model_info[0][0] if isinstance(model_info, list) else model_info[0] + ) + # Get corresponding connection of the _model. 
+ _model_conn = _model.conns.get_con_by_metadata(connection.metadata) + assert _model_conn is not None + key = _model_conn.key + key_index = list(_model.external_keys).index(key) + + return (_model, key_index) + + def _set_differentiability( + self, + config: dict[ConnectionData, bool] | None = None, + trace: bool = False, + /, + **kwargs: bool, ) -> None: updates = Updates() if config is None: @@ -1373,12 +1437,21 @@ def set_differentiability( elif isinstance(key, ConnectionData): if key not in self.conns.all.values(): raise KeyError(f"Connection {key} is not found in the model.") + conn_data = key + updates |= conn_data.set_differentiability(value) - updates |= key.set_differentiability(value) + if trace: + model_info = self.extract_model_key_index(conn_data) + self.assigned_differentiabilities.append((model_info, value)) model = self._get_outermost_parent() model.constraint_solver(updates) + def set_differentiability( + self, config: dict[ConnectionData, bool] | None = None, /, **kwargs: bool + ) -> None: + self._set_differentiability(config, True, **kwargs) + def _set_shapes( self, shapes: Mapping[ConnectionData, ShapeTemplateType] | None = None, @@ -1387,7 +1460,7 @@ def _set_shapes( **kwargs: ShapeTemplateType, ) -> None: # Initialize assigned shapes dictionary to store assigned shapes. - assigned_shapes: dict[str, ShapeTemplateType] = {} + assigned_shapes: list[tuple[tuple[BaseModel, int], ShapeTemplateType]] = [] updates = Updates() if shapes is None: shapes = {} @@ -1406,7 +1479,11 @@ def _set_shapes( assert conn is not None inner_key = conn.key shape_nodes[key] = (given_repr.node, inner_key) - assigned_shapes[inner_key] = shape + # In order to store assigned shapes, we need to store corresponding model + # and index of the connection for that model. + model_info = self.extract_model_key_index(conn) + assigned_shapes.append((model_info, shape)) + # Apply updates to the shape nodes. 
for key in chain(shapes, kwargs): assert isinstance(key, str | ConnectionData) @@ -1444,12 +1521,6 @@ def _set_types( ) -> None: # Initialize assigned shapes dictionary to store assigned shapes. if config is None: config = {} - - assigned_types: dict[ - str, - type | UnionType | ScalarType | type[Tensor[int | float | bool]], - ] = {} - # Get the outermost parent as all the updates will happen here. model = self._get_outermost_parent() updates = Updates() @@ -1458,12 +1529,11 @@ def _set_types( metadata = self.conns.extract_metadata(key) conn = self.conns.get_con_by_metadata(metadata) assert conn is not None - inner_key = conn.key - assigned_types[inner_key] = key_type updates |= metadata.set_type(key_type) - if trace: - # Store assigned types in the model. - self.assigned_types |= assigned_types + if trace: + # Store assigned types in the model. + model_info = self.extract_model_key_index(conn) + self.assigned_types.append((model_info, key_type)) # Run the constraints for updating affected connections. 
model.constraint_solver(updates) diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index cc098ba3..70041bb1 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -536,7 +536,10 @@ def cin(self) -> Connection: return cin def _extend( - self, model: BaseModel, kwargs: Mapping[str, ConnectionType] | None = None + self, + model: BaseModel, + kwargs: Mapping[str, ConnectionType] | None = None, + trace: bool = True, ) -> Self: if kwargs is None: kwargs = {} @@ -560,7 +563,7 @@ def _extend( kwargs[key] = _value # type: ignore kwargs[key] = self._unroll_template(kwargs[key]) # type: ignore - self.extend(model, **kwargs) # type: ignore + self.extend(model, trace, **kwargs) # type: ignore return self def __add__(self, info: ExtendInfo | Model) -> Self: diff --git a/mithril/framework/logical/operator.py b/mithril/framework/logical/operator.py index 85cb9295..02d70a21 100644 --- a/mithril/framework/logical/operator.py +++ b/mithril/framework/logical/operator.py @@ -130,6 +130,8 @@ def class_name(self) -> str: def extend( self, model: BaseModel | BaseModel, + trace: bool = True, + /, **kwargs: ConnectionDataType, ) -> None: raise NotImplementedError("Operators cannot be extended!") diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index f543dc50..20ab841c 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -30,7 +30,7 @@ def __init__( name: str | None = None, ) -> None: super().__init__(name=name, enforce_jit=model._jittable) - self._extend(model, {k: k for k in model.external_keys}) + self._extend(model, {k: k for k in model.external_keys}, trace=False) self.expose_keys(*model.external_keys) @property @@ -42,11 +42,13 @@ def submodel(self) -> Operator: def extend( self, model: BaseModel, + trace: bool = True, + /, **kwargs: ConnectionDataType, ) -> None: if len(self.dag) > 0: raise RuntimeError("Primitive models 
cannot have submodels.") - super().extend(model, **kwargs) + super().extend(model, trace, **kwargs) class PrimitiveModel(OperatorModel): diff --git a/mithril/models/models.py b/mithril/models/models.py index e89260e8..a41395e4 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -1104,6 +1104,8 @@ def __init__( output=IOKey(name="output"), ) + # TODO: It is not clear where these "input1" and "input2" names come from. + # It assumes kernel model has two inputs named "input1" and "input2". shapes: dict[str, ShapeTemplateType] = { "input1": ["N", "d_in"], "input2": ["M", "d_in"], diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 18a2d6be..9630b57b 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -80,13 +80,13 @@ class RegDict(TypedDict): class ModelDict(TypedDict, total=False): name: str args: dict[str, Any] - assigned_shapes: dict[str, ShapeTemplateType] - differentiability_info: dict[str, bool] + assigned_shapes: list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]] + assigned_types: list[tuple[tuple[str, int] | str, str]] + assigned_differentiabilities: list[tuple[tuple[str, int] | str, bool]] assigned_constraints: list[AssignedConstraintType] tuples: list[str] enums: dict[str, str] unnamed_keys: list[str] - types: dict[str, str] submodels: dict[str, ModelDict] connections: dict[str, dict[str, str | ConnectionDict]] canonical_keys: dict[str, tuple[set[str], set[str]]] @@ -125,6 +125,171 @@ class TrainModelDict(TypedDict): enum_dict = {"PaddingType": PaddingType} +def _extract_key_info( + model: BaseModel, submodel_dict: dict[BaseModel, str], info: tuple[BaseModel, int] +) -> tuple[str, int] | str: + m, key_index = info + m_name = submodel_dict[m] if m is not model else "self" + + key_name = list(model.external_keys)[key_index] + if m_name == "self" and not key_name.startswith("$"): + # If corresponding connection is named, save info + # only with this name and 
shape. + key_info: tuple[str, int] | str = key_name + else: + key_info = (m_name, key_index) + + return key_info + + +def _serialize_assigned_info( + model: BaseModel, submodel_dict: dict[BaseModel, str] +) -> tuple[ + list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]], + list[tuple[tuple[str, int] | str, str]], + list[tuple[tuple[str, int] | str, bool]], + list[AssignedConstraintType], +]: + shapes_info: list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]] = [] + types_info: list[tuple[tuple[str, int] | str, str]] = [] + differentiability_info: list[tuple[tuple[str, int] | str, bool]] = [] + constraints_info: list[AssignedConstraintType] = [] + + # Shapes conversion. + for shp_info in model.assigned_shapes: + info_list: list[tuple[tuple[str, int] | str, ShapeTemplateType]] = [] + for sub_info in shp_info: + (m, key_index), shape_info = sub_info + # Key info. + key_info = _extract_key_info(model, submodel_dict, (m, key_index)) + # Shape info. + shape_list: list[int | str | tuple[str, EllipsisType] | None] = [] + for item in shape_info: + if isinstance(item, tuple): # variadic + shape_list.append(f"{item[0]},...") + else: + shape_list.append(item) + # Combine key info with shape info. + info_list.append((key_info, shape_list)) + shapes_info.append(info_list) + + # Types conversion. + for type_info in model.assigned_types: + (m, key_index), typ = type_info + # Key info. + key_info = _extract_key_info(model, submodel_dict, (m, key_index)) + # Combine key info with type info. + if get_origin(typ) is Tensor: + types_info.append((key_info, "tensor")) + elif typ is not ToBeDetermined: + types_info.append((key_info, str(typ))) + + # Differentiability settings for models and keys. + for diff_info in model.assigned_differentiabilities: + (m, key_index), status = diff_info + # Key info. + key_info = _extract_key_info(model, submodel_dict, (m, key_index)) + # Combine key info with differentiability status. 
+ differentiability_info.append((key_info, status)) + + # Constraints. + for constraint in model.assigned_constraints: + constraints_info.append(constraint) + + return shapes_info, types_info, differentiability_info, constraints_info + + +def _extract_connection_from_index( + model: BaseModel, model_name: str, index: int, submodel_dict: dict[str, BaseModel] +) -> ConnectionData: + _model = model if model_name == "self" else submodel_dict[model_name] + connection_key = list(_model.external_keys)[index] + connection = _model.conns.get_connection(connection_key) + assert connection is not None + return connection + + +def _set_assigned_info( + model: BaseModel, + submodel_dict: dict[str, BaseModel], + shapes_info: list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]], + types_info: list[tuple[tuple[str, int] | str, str]], + diffs_info: list[tuple[tuple[str, int] | str, bool]], + constraints_info: list[AssignedConstraintType], +) -> None: + # Shapes conversion. + for shp_info in shapes_info: + shapes: dict[ConnectionData, ShapeTemplateType] = {} + shape_kwargs: dict[str, ShapeTemplateType] = {} + for sub_info in shp_info: + shape = sub_info[1] + if isinstance(sub_info[0], tuple): + model_name, index = sub_info[0] + connection = _extract_connection_from_index( + model, model_name, index, submodel_dict + ) + shapes[connection] = shape + elif isinstance(sub_info[0], str): + name = sub_info[0] + shape_kwargs[name] = shape + else: + raise RuntimeError("Unknown shape info format") + model.set_shapes(shapes, **shape_kwargs) + + # Types conversion. 
+ types_config = {} + types_kwargs = {} + for type_info in types_info: + typ: type + if type_info[1] == "tensor": + typ = Tensor[int | float | bool] + else: + typ = eval(type_info[1]) + + if isinstance(type_info[0], tuple): + model_name, index = type_info[0] + connection = _extract_connection_from_index( + model, model_name, index, submodel_dict + ) + types_config[connection] = typ + elif isinstance(type_info[0], str): + name = type_info[0] + types_kwargs[name] = typ + else: + raise RuntimeError("Unknown type info format!") + + model.set_types(types_config, **types_kwargs) + + # Differentiability settings for models and keys. + diff_config: dict[ConnectionData, bool] = {} + diff_kwargs: dict[str, bool] = {} + for diff_info in diffs_info: + status = diff_info[1] + if isinstance(diff_info[0], tuple): + model_name, index = diff_info[0] + connection = _extract_connection_from_index( + model, model_name, index, submodel_dict + ) + diff_config[connection] = status + elif isinstance(diff_info[0], str): + name = diff_info[0] + diff_kwargs[name] = status + else: + raise RuntimeError("Unknown differentiability info format!") + model.set_differentiability(diff_config, **diff_kwargs) + + # Constraints. + for constr_info in constraints_info: + constrain_fn = constr_info["fn"] + if constrain_fn not in constrain_fn_dict: + raise RuntimeError( + "In the process of creating a model from a dictionary, an unknown" + " constraint function was encountered!" + ) + constrain_fn = constrain_fn_dict[constrain_fn] + model.add_constraint(constrain_fn, keys=constr_info["keys"]) # type: ignore + + def create_iokey_kwargs( info: KeyDict, submodels_dict: dict[str, BaseModel] ) -> dict[str, Any]: @@ -207,27 +372,13 @@ def dict_to_model( } model = type(model_name, (Model,), attrs)(**args) - types: dict[str, str] = params.get("types", {}) - # TODO: Set all types in a bulk. 
- set_types = {} - for key, typ in types.items(): - if typ == "tensor": - set_types[key] = Tensor[int | float | bool] - # else: - # # TODO: Get rid of using eval method. Find more secure - # # way to convert strings into types and generic types. - # set_types[key] = eval(typ) - unnamed_keys: list[str] = params.get("unnamed_keys", []) - differentiability_info: dict[str, bool] = params.get("differentiability_info", {}) - assigned_shapes = params.get("assigned_shapes", {}) - assigned_constraints = params.get("assigned_constraints", []) canonical_keys: dict[str, tuple[set[str], set[str]]] = params.get( "canonical_keys", {} ) assert isinstance(model, Model) - submodels_dict = {} + submodels_dict: dict[str, BaseModel] = {} for m_key, v in submodels.items(): m = dict_to_model(v) submodels_dict[m_key] = m @@ -243,7 +394,7 @@ def dict_to_model( elif isinstance(conn, dict): if (io_key := conn.get("key")) is not None: # TODO: Update this part according to new IOKey structure. - key_kwargs = create_iokey_kwargs(io_key, submodels_dict) # type: ignore + key_kwargs = create_iokey_kwargs(io_key, submodels_dict) if (conns := key_kwargs.pop("connections", None)) is not None: mappings[k] = conns.pop() mergings[k] = conns @@ -258,33 +409,26 @@ def dict_to_model( con = getattr(m, key) model.merge_connections(con, *conns) - if set_types: - model.set_types(**set_types) - if "model" in canonical_keys: if canonical_keys["model"][0] is not None: model.set_cin(*canonical_keys["model"][0]) if canonical_keys["model"][1] is not None: model.set_cout(*canonical_keys["model"][1]) - for key, value in differentiability_info.items(): - con = model.conns.get_connection(key) - assert con is not None - con.set_differentiability(value) - - if len(assigned_constraints) > 0: - for constr_info in assigned_constraints: - constrain_fn = constr_info["fn"] - if constrain_fn not in constrain_fn_dict: - raise RuntimeError( - "In the process of creating a model from a dictionary, an unknown" - " constraint 
function was encountered!" - ) - constrain_fn = constrain_fn_dict[constrain_fn] - model.add_constraint(constrain_fn, keys=constr_info["keys"]) # type: ignore + # Set all assigned info. + assigned_types = params.get("assigned_types", []) + assigned_shapes = params.get("assigned_shapes", []) + assigned_differentiabilities = params.get("assigned_differentiabilities", []) + assigned_constraints = params.get("assigned_constraints", []) + _set_assigned_info( + model, + submodels_dict, + assigned_shapes, + assigned_types, + assigned_differentiabilities, + assigned_constraints, + ) - if len(assigned_shapes) > 0: - model.set_shapes(**dict_to_shape(assigned_shapes)) assert isinstance(model, Model) return model @@ -295,82 +439,64 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: model_name = model.__class__.__name__ args = handle_model_to_dict_args(model_name, model.factory_args) - assigned_shapes: dict[str, ShapeTemplateType] = {} - differentiability_info: dict[str, bool] = {} - assigned_constraints: list[AssignedConstraintType] = [] - types: dict[str, str] = {} - - for key, con in model.conns.all.items(): - edge = con.metadata - if edge.is_tensor and not con.is_autogenerated: - differentiability_info[key] = edge.differentiable - - for shape in model.assigned_shapes: - assigned_shapes |= shape_to_dict(shape) - for constrain in model.assigned_constraints: - assigned_constraints.append(constrain) + model_dict: ModelDict = { + "name": model_name, + "args": args, + } - for key, typ in model.assigned_types.items(): - if get_origin(typ) is Tensor: - types[key] = "tensor" - # elif typ is not ToBeDetermined: - # types[key] = str(typ) - - if ( - model_name != "Model" - and model_name in dir(models) - or model_name not in dir(models) - ): - model_dict: ModelDict = { - "name": model_name, - "args": args, - "assigned_shapes": assigned_shapes, - "differentiability_info": differentiability_info, - "assigned_constraints": assigned_constraints, - "types": types, - } - 
return model_dict + submodel_obj_dict: dict[BaseModel, str] = {} - connection_dict: dict[str, dict[str, str | ConnectionDict]] = {} - canonical_keys: dict[str, tuple[set[str], set[str]]] = {} - submodels: dict[str, ModelDict] = {} + if model_name == "Model" and model_name in dir(models): + connection_dict: dict[str, dict[str, str | ConnectionDict]] = {} + canonical_keys: dict[str, tuple[set[str], set[str]]] = {} + submodels: dict[str, ModelDict] = {} - # IOHyperEdge -> [model_id, connection_name] - submodel_connections: dict[IOHyperEdge, list[str]] = {} - assert isinstance(model, Model) + # IOHyperEdge -> [model_id, connection_name] + submodel_connections: dict[IOHyperEdge, list[str]] = {} + assert isinstance(model, Model) - for idx, submodel in enumerate(model.dag.keys()): - model_id = f"m_{idx}" - submodels[model_id] = model_to_dict(submodel) # type: ignore + for idx, submodel in enumerate(model.dag.keys()): + model_id = f"m_{idx}" + submodels[model_id] = model_to_dict(submodel) # type: ignore + submodel_obj_dict[submodel] = model_id - # Store submodel connections - for key in submodel._all_keys: - submodel_connections.setdefault( - submodel.conns.get_metadata(key), [model_id, key] + # Store submodel connections + for key in submodel._all_keys: + submodel_connections.setdefault( + submodel.conns.get_metadata(key), [model_id, key] + ) + assert isinstance(model, Model) + connection_dict[model_id] = connection_to_dict( + model, submodel, submodel_connections, model_id ) - assert isinstance(model, Model) - connection_dict[model_id] = connection_to_dict( - model, submodel, submodel_connections, model_id - ) - canonical_keys[model_id] = ( - get_keys(submodel.conns.cins), - get_keys(submodel.conns.couts), + canonical_keys[model_id] = ( + get_keys(submodel.conns.cins), + get_keys(submodel.conns.couts), + ) + canonical_keys["model"] = ( + get_keys(model.conns.cins), + get_keys(model.conns.couts), ) - canonical_keys["model"] = (get_keys(model.conns.cins), 
get_keys(model.conns.couts)) - composite_model_dict: ModelDict = { - "name": model_name, - "args": args, - "assigned_shapes": assigned_shapes, - "differentiability_info": differentiability_info, - "assigned_constraints": assigned_constraints, - "types": types, - "connections": connection_dict, - "canonical_keys": canonical_keys, - "submodels": submodels, - } - return composite_model_dict + model_dict |= { + "connections": connection_dict, + "canonical_keys": canonical_keys, + "submodels": submodels, + } + + ( + assigned_shapes, + assigned_types, + assigned_differentiabilities, + assigned_constraints, + ) = _serialize_assigned_info(model, submodel_obj_dict) + model_dict |= {"assigned_shapes": assigned_shapes} + model_dict |= {"assigned_types": assigned_types} + model_dict |= {"assigned_differentiabilities": assigned_differentiabilities} + model_dict |= {"assigned_constraints": assigned_constraints} + + return model_dict def get_keys(canonicals: set[ConnectionData]) -> set[str]: diff --git a/tests/json_files/integration_directed_test.json b/tests/json_files/integration_directed_test.json index 5232cd00..1235a9b4 100644 --- a/tests/json_files/integration_directed_test.json +++ b/tests/json_files/integration_directed_test.json @@ -2402,10 +2402,9 @@ "args": { "kernel": { "name": "PolynomialKernel", - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_types": [ + ["poly_coef", "tensor"] + ] } } }, @@ -2580,7 +2579,10 @@ "test_distance_matrix_1": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true] + ], "args": { "get_final_distance": false } @@ -2625,7 +2627,10 @@ "test_distance_matrix_2": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true] + ], "args": { "get_final_distance": false } @@ 
-2667,7 +2672,10 @@ "test_distance_matrix_3": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true] + ], "args": { "get_final_distance": false } @@ -2717,7 +2725,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1.0, 4.0, 9.0], @@ -2756,7 +2767,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1.0, 4.0], @@ -2792,7 +2806,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 3.0, 27.0, 48.0], @@ -2831,7 +2848,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 4.0], @@ -2864,7 +2884,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 4.0, 1.0], @@ -2901,7 +2924,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 25.0], @@ -2935,7 +2961,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1e+32], @@ -2969,7 +2998,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1e-32], @@ -3004,7 +3036,10 @@ "args": { "prediction_dim": 3, "input_type": "points" - } + }, + 
"assigned_differentiabilities": [ + ["coords", true] + ] }, "static_keys": { "input": [[1.0,1.0,0.0,2.0], [2.0,2.0,1.0,3.0]], @@ -3037,7 +3072,10 @@ "name": "TSNECore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1.0, 9.0], [1.0, 0.0, 4.0], [9.0, 4.0, 0.0]], @@ -3071,7 +3109,10 @@ }, "test_tsne_core_2": { "model": { - "name": "TSNECore" + "name": "TSNECore", + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 16, 1], [16, 0, 9], [1, 9, 0]], @@ -3105,7 +3146,10 @@ }, "test_tsne_core_3": { "model": { - "name": "TSNECore" + "name": "TSNECore", + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 16.0], [16.0, 0.0]], @@ -3140,7 +3184,10 @@ "args": { "prediction_dim": 1, "input_type": "points" - } + }, + "assigned_differentiabilities": [ + ["coords", true] + ] }, "static_keys": { "input": [[5.0], [1.0]], @@ -3175,7 +3222,10 @@ "args": { "prediction_dim": 1, "input_type": "points" - } + }, + "assigned_differentiabilities": [ + ["coords", true] + ] }, "static_keys": { "input": [[0.0], [1.0], [3.0]], @@ -3209,11 +3259,11 @@ "test_polynomial_kernel_1": { "model": { "name": "PolynomialKernel", - "differentiability_info": {"input2": true}, - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_differentiabilities": [ + ["input2", true], + ["poly_coef", true], + ["degree", true] + ] }, "inputs": { "input2": [[-1.0, 2.0], [1.0, 3.0]], @@ -3250,11 +3300,11 @@ "test_polynomial_kernel_2": { "model": { "name": "PolynomialKernel", - "differentiability_info": {"input1": true, "input2": true}, - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_differentiabilities": [ + ["input1", true], + ["poly_coef", true], + ["degree", true] + ] }, "inputs": { "input1": [[-1.0, 2.0, 1.0], [1.0, 3.0, 2.0], [-2, 1, -1]], diff --git 
a/tests/json_files/models_directed_test.json b/tests/json_files/models_directed_test.json index 50ee0b79..c85bba6a 100644 --- a/tests/json_files/models_directed_test.json +++ b/tests/json_files/models_directed_test.json @@ -2,7 +2,7 @@ "test_linear_1": { "model": { "name": "Linear", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { @@ -27,7 +27,7 @@ "test_linear_2": { "model": { "name": "Linear", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]], @@ -51,7 +51,7 @@ "test_linear_3": { "model": { "name": "Linear", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2], [0.2, 0.3], [0.4, 0.3]], @@ -75,7 +75,7 @@ "test_linear_4": { "model": { "name": "Linear", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2], [0.2, 0.5], [0.3, 0.1]], @@ -99,7 +99,7 @@ "test_layernorm_1": { "model": { "name": "LayerNorm", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2], [0.2, 0.5], [0.3, 0.1]], @@ -127,7 +127,7 @@ "test_layernorm_2": { "model": { "name": "LayerNorm", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2, 3.0], [-5, 0.2, 0.5], [0.3,12.0, 0.1]], @@ -155,7 +155,12 @@ "test_RBFKernel_1": { "model": { "name": "RBFKernel", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["l_scale", true], + ["sigma", true] + ] }, "inputs": { "input1": [[1.0], [2.0], [3.0]], @@ -186,7 +191,12 @@ "description": "We use this setting as input to test_SVMLayerWithReg_2", "model": { "name": "RBFKernel", - "differentiability_info": 
{"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["l_scale", true], + ["sigma", true] + ] }, "inputs": { "input1": [[1.0], [2.0], [3.0]], @@ -216,11 +226,12 @@ "test_PolynomialKernel": { "model": { "name": "PolynomialKernel", - "differentiability_info": {"input1": true, "input2": true}, - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["poly_coef", true], + ["degree", true] + ] }, "inputs": { "input1": [[1.0, 2], [2, 3], [3, 4]], @@ -273,7 +284,7 @@ "test_logistic_1_logits_ignored": { "model": { "name": "LogisticRegression", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]], @@ -298,7 +309,7 @@ "test_logistic_2_no_ignore": { "model": { "name": "LogisticRegression", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]], @@ -324,10 +335,14 @@ "test_distance_matrix_1": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, "args": { "get_final_distance": false - } + }, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[1.0], [2], [3]], @@ -355,7 +370,11 @@ "test_distance_matrix_2": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[1.0], [2], [3]], @@ -383,7 +402,11 @@ "test_distance_matrix_3": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[1.0, 
2.0], @@ -419,7 +442,11 @@ "test_distance_matrix_4": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ], "args": { "get_final_distance": false } @@ -451,7 +478,11 @@ "test_distance_matrix_5": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[1.0], [3], [4]], @@ -480,7 +511,11 @@ "test_distance_matrix_6": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[3.0, 2], [4, 1], [4, 3]], @@ -515,7 +550,8 @@ "name": "PolynomialFeatures", "args": { "degree": 2 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.0, 3.0]] @@ -537,7 +573,8 @@ "name": "PolynomialFeatures", "args": { "degree": 2 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2], [3], [4]] @@ -559,7 +596,8 @@ "name": "PolynomialFeatures", "args": { "degree": 2 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 1], [2, 3], [3, -1], [0, 10]] @@ -587,7 +625,8 @@ "name": "PolynomialFeatures", "args": { "degree": 3 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2], [2, 1]] @@ -609,12 +648,13 @@ "test_mds_core_1": { "model": { "name": "MDSCore", - "differentiability_info": {"distances": true}, + "assigned_differentiabilities": [ + ["distances", true], + ["pred_distances", true], + ["norm", true] + ], "args": { "exact_distances": false - }, - "types": { - "norm": "tensor" } }, "inputs": { @@ -651,12 +691,13 @@ "test_mds_core_2": { "model": { 
"name": "MDSCore", - "differentiability_info": {"distances": true}, + "assigned_differentiabilities": [ + ["distances", true], + ["pred_distances", true], + ["norm", true] + ], "args": { "exact_distances": false - }, - "types": { - "norm": "tensor" } }, "inputs": { @@ -689,12 +730,13 @@ "test_mds_core_3": { "model": { "name": "MDSCore", - "differentiability_info": {"distances": true}, + "assigned_differentiabilities": [ + ["distances", true], + ["pred_distances", true], + ["norm", true] + ], "args": { "exact_distances": false - }, - "types": { - "norm": "tensor" } }, "inputs": { @@ -727,12 +769,13 @@ "test_mds_core_4": { "model": { "name": "MDSCore", - "differentiability_info": {"distances": true}, + "assigned_differentiabilities": [ + ["distances", true], + ["pred_distances", true], + ["norm", true] + ], "args": { "exact_distances": false - }, - "types": { - "norm": "tensor" } }, "inputs": { @@ -764,7 +807,11 @@ "NOTE": "Prove the differences of diagonal entries of input gradient wrt composite-ml is because of linearized models like power, log etc.", "model": { "name": "MDS", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [ + ["input", true], + ["coords", true], + ["norm", true] + ], "args": { "prediction_dim": 1, "input_type": "powered_distances" @@ -802,7 +849,9 @@ "test_tsne_core_1": { "model": { "name": "TSNECore", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [ + ["pred_distances", true] + ], "args": { "exact_distances": false } @@ -844,7 +893,10 @@ "name": "TSNECore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "inputs": { "pred_distances": [[0.0, 16, 1], [16, 0, 9], [1, 9, 0]] @@ -886,7 +938,10 @@ "name": "TSNECore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "inputs": { "pred_distances": [[0.0, 16], [16, 0]] @@ -922,7 +977,10 @@ "args": { 
"prediction_dim": 1, "input_type": "points" - } + }, + "assigned_differentiabilities": [ + ["coords", true] + ] }, "discard_keys": ["predicted_coords"], "inputs": { @@ -950,7 +1008,9 @@ "test_polynomial_1D_Degree_1": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [ + ["input", true] + ], "args": { "degree": 1, "dimension": 1 @@ -978,7 +1038,7 @@ "test_polynomial_2D_Degree_1": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 1, "dimension": 1 @@ -1006,7 +1066,7 @@ "test_polynomial_1D_Degree_2": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 2, "dimension": 1 @@ -1034,7 +1094,7 @@ "test_polynomial_1D_Degree_3": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 3, "dimension": 1 @@ -1062,7 +1122,7 @@ "test_polynomial_2D_Degree_2": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 2, "dimension": 1 @@ -1090,7 +1150,7 @@ "test_polynomial_2D_Degree_2_multilabel": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 2, "dimension": 1 @@ -2070,7 +2130,7 @@ "test_maxpool1d_1": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 } @@ -2101,7 +2161,7 @@ "test_maxpool1d_2": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 3 } @@ 
-2130,7 +2190,7 @@ "test_maxpool1d_3": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "stride": 1 @@ -2162,7 +2222,7 @@ "test_maxpool1d_4": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "padding": [0,1] @@ -2195,7 +2255,7 @@ "test_maxpool2d_1": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 }, @@ -2237,7 +2297,7 @@ "test_maxpool2d_2": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 } @@ -2366,7 +2426,7 @@ "test_maxpool2d_3": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 3 } @@ -2403,7 +2463,7 @@ "test_maxpool2d_4": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "stride": 1 @@ -2449,7 +2509,7 @@ "test_maxpool2d_5": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "padding": [[1,0],[0,1]] @@ -2493,7 +2553,8 @@ }, "test_where_1": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": { "cond": [[[ @@ -2542,7 +2603,8 @@ }, "test_where_2": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": { "cond": [[[ @@ -2591,7 +2653,8 @@ }, "test_where_3": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": 
[["input1", true], ["input2", true]] }, "static_keys": {"cond": [true]}, "inputs": { @@ -2634,7 +2697,8 @@ }, "test_where_4": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": {"cond": [true]}, "inputs": { @@ -2673,7 +2737,8 @@ }, "test_where_5": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": {"cond": [false]}, "inputs": { @@ -2713,7 +2778,7 @@ "test_nn_1": { "model": { "name": "MLP", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "dimensions": [2, 2, 2], "activations": ["relu", "relu", "buffer"] @@ -2749,9 +2814,7 @@ "test_buffer_1": { "model": { "name": "Buffer", - "types": { - "input": "tensor" - } + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]] @@ -2773,9 +2836,7 @@ "model": { "name": "Buffer", - "types": { - "input": "tensor" - } + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [2.0, 0.0]] @@ -2797,9 +2858,7 @@ "model": { "name": "Buffer", - "types": { - "input": "tensor" - } + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [2.0, 0.0]] @@ -2821,7 +2880,8 @@ "test_matmul_1": { "model": { - "name": "MatrixMultiply" + "name": "MatrixMultiply", + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0, 3.0, 5.0]], @@ -2844,10 +2904,7 @@ "test_mult_1": { "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0], [2.0], [3.0], [4.0]], @@ -2872,10 +2929,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { 
"left": [2.0], @@ -2900,10 +2954,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -2928,10 +2979,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -2956,10 +3004,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0], [3.0, 4.0]], @@ -2984,10 +3029,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3010,10 +3052,7 @@ "test_mult_7": { "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0],[2],[3]], @@ -3035,10 +3074,7 @@ "test_div_1": { "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [[1.0], [2.0], [3.0], [4.0]], @@ -3063,10 +3099,7 @@ "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [2.0], @@ -3091,10 +3124,7 @@ "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3119,10 +3149,7 @@ "model": { "name": "Divide", - "types": { - "numerator": 
"tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [1.0, 2.0], @@ -3147,10 +3174,7 @@ "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [[1.0, 2.0], [3.0, 4.0]], @@ -3173,10 +3197,7 @@ "test_div_6": { "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [1.0, 2.0], @@ -3199,10 +3220,7 @@ "test_sum_1": { "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0], [2.0], [3.0], [4.0]], @@ -3227,10 +3245,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [2.0], @@ -3255,10 +3270,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3283,10 +3295,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3311,10 +3320,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0], [3.0, 4.0]], @@ -3339,10 +3345,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3366,10 
+3369,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[[2.0,3,4]]]]], @@ -3394,10 +3394,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[[[[[[[[[1.0, 1.0],[1.0, 1.0]]]]]]]]]]]], @@ -3422,10 +3419,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[[[1.0]]]],[[[[1.0]]]]], [[[[[1.0]]]], [[[[1.0]]]]]], @@ -3449,10 +3443,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0], [2.0], [3.0], [4.0]], @@ -3477,10 +3468,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [2.0], @@ -3505,10 +3493,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3533,10 +3518,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3561,10 +3543,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0], [3.0, 4.0]], @@ -3589,10 +3568,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] 
}, "inputs": { "left": [1.0, 2.0], @@ -3617,10 +3593,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[1.0], [2.0], [3.0], [4.0]], @@ -3645,10 +3618,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [2.0], @@ -3673,10 +3643,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3701,10 +3668,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [1.0, 2.0], @@ -3729,10 +3693,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[1.0, 2.0], [3.0, 4.0]], @@ -3756,10 +3717,7 @@ "test_power_6": { "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [1.0, 2.0], @@ -3782,7 +3740,8 @@ "test_exp": { "model": { - "name": "Exponential" + "name": "Exponential", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [3.0, 4.0]] @@ -3802,7 +3761,8 @@ }, "test_sqrt_1": { "model": { - "name": "Sqrt" + "name": "Sqrt", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [3.0, 4.0]] @@ -3822,7 +3782,8 @@ }, "test_sqrt_2": { "model": { - "name": "Sqrt" + "name": "Sqrt", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[4.0, 16.0], [25.0, 100.0]]]] @@ -3842,7 
+3803,8 @@ }, "test_sqrt_3": { "model": { - "name": "Sqrt" + "name": "Sqrt", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [10000.0] @@ -3864,7 +3826,8 @@ "test_abs": { "model": { - "name": "Absolute" + "name": "Absolute", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, 0.0], [1.0, -2.0]] @@ -3889,7 +3852,8 @@ "test_sin": { "model": { - "name": "Sine" + "name": "Sine", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.5235987755982988, 1.0471975511965976], [1.0, 45.0], [90.0, 145.0]] @@ -3909,7 +3873,8 @@ "test_cos": { "model": { - "name": "Cosine" + "name": "Cosine", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.5235987755982988, 1.0471975511965976], [1.0, 45.0], [90.0, 145.0]] @@ -3951,7 +3916,8 @@ "test_flatten": { "model": { - "name": "Flatten" + "name": "Flatten", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]] @@ -3971,7 +3937,8 @@ "test_flatten2": { "model": { - "name": "Flatten" + "name": "Flatten", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -3995,7 +3962,8 @@ "args": { "start_dim": 1, "end_dim": -1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4019,7 +3987,8 @@ "args": { "start_dim": -1, "end_dim": -1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4043,7 +4012,8 @@ "args": { "start_dim": 1, "end_dim": 2 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4065,7 +4035,8 @@ "test_transpose": { "model": { - "name": "Transpose" + "name": "Transpose", + "assigned_differentiabilities": [["input", 
true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4090,7 +4061,8 @@ "args": { "axes": [0,2,1] }, - "tuples": ["axes"] + "tuples": ["axes"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 3.0], [2.0, 4.0]], [[1.0, 3.0], [2.0, 4.0]], [[1.0, 3.0], [2.0, 4.0]]] @@ -4116,7 +4088,8 @@ "args": { "axes": [0] }, - "tuples": ["axes"] + "tuples": ["axes"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [3.0] @@ -4141,7 +4114,8 @@ "args": { "axes": [1,4,3,2,0] }, - "tuples": ["axes"] + "tuples": ["axes"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[2.0]]]], [[[[3.0]]]]] @@ -4166,7 +4140,8 @@ "args": { "axes": [2,4,3,0,1] }, - "tuples": ["axes"] + "tuples": ["axes"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[2.0]]]], [[[[3.0]]]]] @@ -4286,7 +4261,8 @@ "test_tanh_1": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[10.0]] @@ -4308,7 +4284,8 @@ "test_tanh_2": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[30.0]] @@ -4330,7 +4307,8 @@ "test_tanh_3": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.220446049250313e-16]] @@ -4352,7 +4330,8 @@ "test_tanh_4": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4374,7 +4353,8 @@ "test_tanh_5": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4397,7 +4377,8 @@ "test_sigmoid_1": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[20.0]] @@ -4419,7 
+4400,8 @@ "test_sigmoid_2": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4441,7 +4423,8 @@ "test_sigmoid_3": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-30.0]] @@ -4463,7 +4446,8 @@ "test_sigmoid_4": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[919.78546867]] @@ -4485,7 +4469,8 @@ "test_sigmoid_5": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-919.78546867]] @@ -4507,7 +4492,8 @@ "test_sigmoid_6": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4529,7 +4515,8 @@ "test_softplus_1": { "model": { - "name": "Softplus" + "name": "Softplus", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.0]] @@ -4551,7 +4538,8 @@ "test_softplus_2": { "model": { - "name": "Softplus" + "name": "Softplus", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0, 2.0, -30]] @@ -4573,7 +4561,8 @@ "test_softplus_3": { "model": { - "name": "Softplus" + "name": "Softplus", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4595,7 +4584,8 @@ "test_softplus_4": { "model": { - "name": "Softplus" + "name": "Softplus", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4616,7 +4606,8 @@ "test_permute_tensor_1": { "model": { - "name": "PermuteTensor" + "name": "PermuteTensor", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [ @@ -4654,7 +4645,8 @@ }, "test_permute_tensor_2": { "model": { - "name": "PermuteTensor" + "name": 
"PermuteTensor", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [ @@ -4692,7 +4684,11 @@ }, "test_squared_error_1": { "model": { - "name": "SquaredError" + "name": "SquaredError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0], [2], [3], [4]], @@ -4715,7 +4711,11 @@ "test_squared_error_2": { "model": { - "name": "SquaredError" + "name": "SquaredError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0], [-1000000.0], [0.000000000001], [-1.5]], @@ -4737,7 +4737,11 @@ }, "test_squared_error_3": { "model": { - "name": "SquaredError" + "name": "SquaredError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[0.1, 0.2], [1000000000.0, 0.0000000001]], @@ -4759,7 +4763,11 @@ }, "test_squared_error_4": { "model": { - "name": "SquaredError" + "name": "SquaredError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0, 2], [3, 4], [5, 6]], @@ -4781,7 +4789,8 @@ }, "test_hinge_loss_1": { "model": { - "name": "HingeLoss" + "name": "HingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.0], [0.25]] @@ -4806,7 +4815,8 @@ "test_hinge_loss_2": { "model": { - "name": "HingeLoss" + "name": "HingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [-1.0]] @@ -4831,7 +4841,8 @@ }, "test_hinge_loss_3": { "model": { - "name": "HingeLoss" + "name": "HingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0], [0.0]] @@ -4853,7 +4864,11 @@ }, "test_absolute_error_1": { "model": { - "name": "AbsoluteError" + "name": "AbsoluteError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]], @@ -4882,7 +4897,11 @@ }, "test_absolute_error_2": { "model": { - 
"name": "AbsoluteError" + "name": "AbsoluteError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0], [-1000000.0], [0.000000000001], [-1.5]], @@ -4912,7 +4931,11 @@ "test_absolute_error_3": { "model": { - "name": "AbsoluteError" + "name": "AbsoluteError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[0.1, 0.2], [1000000000.0, 0.0000000001]], @@ -4942,7 +4965,11 @@ "test_absolute_error_4": { "model": { - "name": "AbsoluteError" + "name": "AbsoluteError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], @@ -4972,7 +4999,8 @@ "test_cross_entropy_1": { "model": { - "name": "CrossEntropy" + "name": "CrossEntropy", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 1.0, 1.0, 1.0]] @@ -4995,7 +5023,8 @@ "test_cross_entropy_2": { "model": { - "name": "CrossEntropy" + "name": "CrossEntropy", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1000.0, 0.0], [0.0, 1000.0]] @@ -5018,7 +5047,8 @@ "test_cross_entropy_3": { "model": { - "name": "CrossEntropy" + "name": "CrossEntropy", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 1.0], [0.0, 1000.0]] @@ -5043,7 +5073,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] @@ -5070,7 +5101,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] @@ -5101,7 +5133,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[2.220446049250313e-16, 1.0], [0.1, 0.9]] @@ -5128,7 +5161,8 @@ "name": "CrossEntropy", "args": { 
"input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0, 1.0]] @@ -5159,7 +5193,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.2, 1.1102230246251565e-16, 0.7999999999999998], [0.1, 0.6, 0.3]] @@ -5185,7 +5220,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.5, 0.5], [0.1, 0.9]] @@ -5212,7 +5248,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.5, 0.5], [0.1, 0.9]] @@ -5239,7 +5276,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] @@ -5265,7 +5303,8 @@ "name": "CrossEntropy", "args": { "input_type": "log_probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[-0.6931471805599453, -0.6931471805599453], [-2.3025850929940455, -0.10536051565782628]] @@ -5291,7 +5330,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.1], [0.5]] @@ -5317,7 +5357,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.1, 0.2, 0.3], [0.5, 0.4, 0.2]] @@ -5347,7 +5388,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1e-20, 0.2, 0.3], [0.5, 0.4, 1e-20]] @@ -5377,7 +5419,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2, 0.3], [0.5, 1.0, 0.2]] @@ -5408,7 +5451,8 @@ "name": 
"BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.1], [0.5]] @@ -5434,7 +5478,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.1102230246251565e-16], [0.5]] @@ -5460,7 +5505,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0102230246251565e-16], [0.5]] @@ -5486,7 +5532,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "logits" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [ @@ -5527,7 +5574,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.1102230246251565e-16], [0.95]] @@ -5553,7 +5601,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "logits" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-36.83119036496738], [0.0]] @@ -5583,7 +5632,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "logits" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.0], [-2]] @@ -5609,7 +5659,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "logits" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.0, -2, 3, 0], [0, 1, 2, -1]] @@ -5641,7 +5692,8 @@ }, "test_quantile_loss_1": { "model": { - "name": "QuantileLoss" + "name": "QuantileLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0]] @@ -5664,7 +5716,8 @@ }, "test_quantile_loss_2": { "model": { - "name": "QuantileLoss" + "name": "QuantileLoss", + "assigned_differentiabilities": [["input", true]] }, "static_keys": { "quantile": 0.1, @@ -5687,7 +5740,8 @@ }, "test_quantile_loss_3": { "model": { - "name": "QuantileLoss" + 
"name": "QuantileLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [1e-5]] @@ -5710,7 +5764,8 @@ }, "test_quantile_loss_4": { "model": { - "name": "QuantileLoss" + "name": "QuantileLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.0, 2.0], [0.0, 1.0]] @@ -5729,7 +5784,8 @@ }, "test_quad_hinge_loss_1": { "model": { - "name": "QuadHingeLoss" + "name": "QuadHingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[2.0], [0.25]] @@ -5751,7 +5807,8 @@ }, "test_quad_hinge_loss_2": { "model": { - "name": "QuadHingeLoss" + "name": "QuadHingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.0], [-1.0]] @@ -5774,7 +5831,8 @@ "test_quad_hinge_loss_3": { "model": { - "name": "QuadHingeLoss" + "name": "QuadHingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0], [0.0]] @@ -5796,7 +5854,8 @@ }, "test_kl_div_1": { "model": { - "name": "KLDivergence" + "name": "KLDivergence", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.5], [0.5, 0.1]] @@ -5818,7 +5877,8 @@ }, "test_kl_div_3": { "model": { - "name": "KLDivergence" + "name": "KLDivergence", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.5], [0.5, 0.1]] @@ -5840,7 +5900,8 @@ }, "test_kl_div_4": { "model": { - "name": "KLDivergence" + "name": "KLDivergence", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.1102230246251565e-16, 0.5], [0.5, 0.1]] @@ -5862,7 +5923,8 @@ }, "test_relu_1": { "model": { - "name": "Relu" + "name": "Relu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -5881,7 +5943,8 @@ }, "test_relu_2": { "model": { - "name": "Relu" + "name": "Relu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0]] @@ -5900,7 +5963,8 @@ }, "test_relu_3": { 
"model": { - "name": "Relu" + "name": "Relu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] @@ -5919,7 +5983,8 @@ }, "test_relu_4": { "model": { - "name": "Relu" + "name": "Relu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] @@ -5938,7 +6003,8 @@ }, "test_leaky_relu_1": { "model": { - "name": "LeakyRelu" + "name": "LeakyRelu", + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"slope": 0.2}, "inputs": { @@ -5962,7 +6028,8 @@ }, "test_leaky_relu_2": { "model": { - "name": "LeakyRelu" + "name": "LeakyRelu", + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"slope": 0.2}, "inputs": { @@ -5986,7 +6053,8 @@ }, "test_leaky_relu_3": { "model": { - "name": "LeakyRelu" + "name": "LeakyRelu", + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"slope": 0.2}, "inputs": { @@ -6010,7 +6078,8 @@ }, "test_gelu_1": { "model": { - "name": "Gelu" + "name": "Gelu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6029,7 +6098,8 @@ }, "test_gelu_2": { "model": { - "name": "Gelu" + "name": "Gelu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] @@ -6048,7 +6118,8 @@ }, "test_stop_gradient_1": { "model": { - "name": "StopGradient" + "name": "StopGradient", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6067,7 +6138,8 @@ }, "test_stop_gradient_2": { "model": { - "name": "StopGradient" + "name": "StopGradient", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] @@ -6086,7 +6158,8 @@ }, "test_softmax_1": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0, 
2.0, -30]] @@ -6105,7 +6178,8 @@ }, "test_softmax_2": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.0, 0.0]] @@ -6124,7 +6198,8 @@ }, "test_softmax_3": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6143,7 +6218,8 @@ }, "test_softmax_4": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0, 2.0, -30]] @@ -6162,7 +6238,8 @@ }, "test_softmax_5": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.0, 0.0]] @@ -6181,7 +6258,8 @@ }, "test_softmax_6": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6200,7 +6278,8 @@ }, "test_softmax_7": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6219,7 +6298,8 @@ }, "test_softmax_8": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"axis": 0}, "inputs": { @@ -6242,7 +6322,8 @@ "name": "Log", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.0, 2.0], [3.0, 4.0], [4.0, 100.0]] @@ -6264,7 +6345,8 @@ "name": "Log", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[0.0]]] @@ -6290,7 +6372,8 @@ "name": "Log", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1e-311, 1e-306]]] @@ -6313,7 +6396,8 @@ }, "test_stable_reciprocal_1": { "model": { - "name": "StableReciprocal" + "name": 
"StableReciprocal", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[2.0, 2.0], [3.0, 4.0], [4.0, 100.0]]]]] @@ -6332,7 +6416,8 @@ }, "test_stable_reciprocal_2": { "model": { - "name": "StableReciprocal" + "name": "StableReciprocal", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0]] @@ -6351,7 +6436,8 @@ }, "test_stable_reciprocal_3": { "model": { - "name": "StableReciprocal" + "name": "StableReciprocal", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1e-155, 1e-145]] @@ -6373,7 +6459,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [3.0, 4.0]] @@ -6395,7 +6482,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[4.0, 16.0], [25.0, 100.0]]]] @@ -6417,7 +6505,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [10000.0] @@ -6439,7 +6528,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0, -4.0], [-1.0, -4.0]] @@ -6465,7 +6555,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"cutoff": 1e-20}, "inputs": { @@ -6493,10 +6584,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[0.0], [2.2250738585072014e-308]], @@ -6521,10 +6609,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[2.0], [3.0], [4.0]], @@ -6553,10 +6638,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": 
"tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [2.0], @@ -6587,10 +6669,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [2.0], @@ -6615,10 +6694,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [1.0, 2.0], @@ -6644,10 +6720,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [1.0, 2.0], @@ -6673,10 +6746,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[1e-311, 1e-311]], @@ -6910,7 +6980,8 @@ }, "test_var_1": { "model": { - "name": "Variance" + "name": "Variance", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0], [5.0]] @@ -6933,7 +7004,8 @@ "args": { "correction": 0.0, "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0], [5.0]] @@ -6956,7 +7028,8 @@ "args": { "correction": 0.0, "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0], [5.0]] @@ -6978,7 +7051,8 @@ "name": "Variance", "args": { "correction": 1.0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0], [5.0]] @@ -6997,7 +7071,8 @@ }, "test_reduce_sum_1": { "model": { - "name": "Sum" + "name": "Sum", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": 
[[[[[[[[[[[[[[[1.0,2.0,3.0],[8.0,6.0,1.0]],[[7.0,6.0,1.0],[11.0,9.0,3.0]]]]]]]]]]]]]]] @@ -7016,7 +7091,8 @@ }, "test_reduce_sum_2": { "model": { - "name": "Sum" + "name": "Sum", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-7.0, -8.0], [6.0, 3.0], [4.0, 5.0]] @@ -7038,7 +7114,8 @@ "name": "Sum", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7060,7 +7137,8 @@ "name": "Sum", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7082,7 +7160,8 @@ "name": "Sum", "args": { "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7105,7 +7184,8 @@ "args": { "axis": [0,2] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]] @@ -7128,7 +7208,8 @@ "args": { "axis": [0,1,2,3,4] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]]]]]]] @@ -7147,7 +7228,8 @@ }, "test_reduce_mean_1": { "model": { - "name": "Mean" + "name": "Mean", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[[[[[[[[1.0,2.0,3.0],[8.0,6.0,1.0]],[[7.0,6.0,1.0],[11.0,9.0,3.0]]]]]]]]]]]]]]] @@ -7166,7 +7248,8 @@ }, "test_reduce_mean_2": { "model": { - "name": "Mean" + "name": "Mean", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-7.0, -8.0], [6.0, 3.0], [4.0, 5.0]] @@ -7188,7 +7271,8 @@ "name": "Mean", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, 
-2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7210,7 +7294,8 @@ "name": "Mean", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7232,7 +7317,8 @@ "name": "Mean", "args": { "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7255,7 +7341,8 @@ "args": { "axis": [0,2] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]] @@ -7278,7 +7365,8 @@ "args": { "axis": [0,1,2,3,4] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]]]]]]] @@ -7297,7 +7385,8 @@ }, "test_reduce_max_1": { "model": { - "name": "Max" + "name": "Max", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[[[[[[[[1.0,2.0,3.0],[8.0,6.0,1.0]],[[7.0,6.0,1.0],[11.0,9.0,3.0]]]]]]]]]]]]]]] @@ -7316,7 +7405,8 @@ }, "test_reduce_max_2": { "model": { - "name": "Max" + "name": "Max", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-7.0, -8.0], [6.0, 6.0], [6.0, 5.0]] @@ -7338,7 +7428,8 @@ "name": "Max", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7360,7 +7451,8 @@ "name": "Max", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7382,7 +7474,8 @@ "name": "Max", "args": { "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7405,7 +7498,8 @@ "args": { "axis": [0,2] }, - 
"tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]] @@ -7428,7 +7522,8 @@ "args": { "axis": [0,1,2,3,4] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]]]]]]] @@ -7447,7 +7542,8 @@ }, "test_reduce_min_1": { "model": { - "name": "Min" + "name": "Min", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[[[[[[[[1.0,2.0,3.0],[8.0,6.0,1.0]],[[7.0,6.0,1.0],[11.0,9.0,3.0]]]]]]]]]]]]]]] @@ -7466,7 +7562,8 @@ }, "test_reduce_min_2": { "model": { - "name": "Min" + "name": "Min", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-7.0, -8.0], [6.0, 6.0], [6.0, 5.0]] @@ -7488,7 +7585,8 @@ "name": "Min", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7510,7 +7608,8 @@ "name": "Min", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7532,7 +7631,8 @@ "name": "Min", "args": { "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7555,7 +7655,8 @@ "args": { "axis": [0,2] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]] @@ -7578,7 +7679,8 @@ "args": { "axis": [0,1,3,5,6] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": 
[[[[[[[[[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]]]]]]]]] @@ -7623,7 +7725,8 @@ "right": {"key": {"connect": [["m2","output"]]}}, "output": "output" } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[1.0, 2.0, 3.0]]]], @@ -7760,7 +7863,8 @@ "expose": true }} } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[1.0, 2.0, 3.0]]]], @@ -7971,7 +8075,8 @@ "expose": true }} } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0, 3.0], @@ -8153,7 +8258,8 @@ "expose": true }} } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[1.0, 2.0, 3.0]]]], @@ -8278,7 +8384,8 @@ "right": {"key": {"name": "right", "type": "tensor"}}, "output": {"key": {"connect": [["m2","right"]]}} } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[1.0, 2.0, 3.0]]]], @@ -8334,7 +8441,8 @@ } - } + }, + "assigned_differentiabilities": [["input1", true]] }, "inputs": { "input1": [[[[1.0, 2.0, 3.0]]]] @@ -8385,7 +8493,8 @@ "m1": { "input": {"key": {"connect": [["m2","input"]]}} } - } + }, + "assigned_differentiabilities": [["my_input", true]] }, "inputs": { "my_input": [[[[1.0, 2.0, 3.0]]]] @@ -8436,7 +8545,8 @@ "m1": { "input": {"key": {"connect": [["m2","input"]]}} } - } + }, + "assigned_differentiabilities": [["my_input", true]] }, "inputs": { "my_input": [[[[1.0, 2.0, 3.0]]]] @@ -8579,7 +8689,10 @@ "test_rnn_cell_1": { "model": { "name": "RNNCell", - "differentiability_info": {"input": true, "prev_hidden": true} + "assigned_differentiabilities": [ + ["input", true], + ["prev_hidden", true] + ] }, "inputs": { @@ -8618,7 +8731,10 @@ "test_rnn_cell_2": { "model": { "name": "RNNCell", - "differentiability_info": {"input": true, "prev_hidden": true} + "assigned_differentiabilities": [ + 
["input", true], + ["prev_hidden", true] + ] }, "inputs": { @@ -8657,7 +8773,11 @@ "test_lstm_cell_1": { "model": { "name": "LSTMCell", - "differentiability_info": {"input": true, "prev_hidden": true, "prev_cell": true} + "assigned_differentiabilities": [ + ["input", true], + ["prev_hidden", true], + ["prev_cell", true] + ] }, "inputs": { diff --git a/tests/json_files/no_grad_inference_test.json b/tests/json_files/no_grad_inference_test.json index c99b9e1e..474fb0c9 100644 --- a/tests/json_files/no_grad_inference_test.json +++ b/tests/json_files/no_grad_inference_test.json @@ -19,6 +19,11 @@ "name": "Multiply" } }, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["input3", true] + ], "connections": { "m1": { "left": {"key": {"name": "input1", "type": "tensor"}}, @@ -96,6 +101,11 @@ "name": "Multiply" } }, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["input3", true] + ], "connections": { "m1": { "left": {"key": {"name": "input1", "type": "tensor"}}, @@ -172,6 +182,11 @@ "name": "Multiply" } }, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["input3", true] + ], "connections": { "m1": { "left": {"key": {"name": "input1", "type": "tensor"}}, diff --git a/tests/json_files/randomized_model_tests_all_backends.json b/tests/json_files/randomized_model_tests_all_backends.json index 2492a775..3afcfe71 100644 --- a/tests/json_files/randomized_model_tests_all_backends.json +++ b/tests/json_files/randomized_model_tests_all_backends.json @@ -1631,7 +1631,10 @@ "test_one_to_many_rnn": { "model": { "name": "OneToMany", - "differentiability_info": {"input": true, "initial_hidden": true}, + "assigned_differentiabilities": [ + ["input", true], + ["initial_hidden", true] + ], "regular_args": { "max_sequence_length": 4, "cell_type": "RNNCell" @@ -2070,11 +2073,13 @@ "test_polynomial_kernel": { "model": { "name": "PolynomialKernel", - "differentiability_info": {"input1": true, "input2": true}, 
- "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true] + ], + "assigned_types": [ + ["poly_coef", "tensor"] + ] }, "static_input_info": { "poly_coef": { diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 7f8cfa16..516f3088 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -59,17 +59,15 @@ def evaluate_case( else: is_list = static_keys.get("is_list", False) static_keys[key] = convert_to_array(backend, value, is_list) + # Set logical tensor values for trainable keys if and only if - # model has any output keys. - if model.output_keys: + # model has list type inputs and has any output keys. + if is_inputs_list and model.output_keys: values: dict[str, Tensor[float] | list[Tensor[float]]] = {} for key, value in reference_gradients.items(): - if is_inputs_list: - values[key] = [ - Tensor[float](differentiable=True) for _ in range(len(value)) - ] - else: - values[key] = Tensor[float](differentiable=True) + values[key] = [ + Tensor[float](differentiable=True) for _ in range(len(value)) + ] model.set_values(**values) models.append(model) diff --git a/tests/scripts/summary_txts/test_physical_summary_17 b/tests/scripts/summary_txts/test_physical_summary_17 index 78c25e8e..873e0802 100644 --- a/tests/scripts/summary_txts/test_physical_summary_17 +++ b/tests/scripts/summary_txts/test_physical_summary_17 @@ -18,16 +18,16 @@ Total Parameters : >0 ------------------------------- - MatrixMultiply ----------------------------------------------------------------------------------------------------------- -Model Name | Model Keys - | --------------------------------------------------------------------------- - | Keys : Shapes : Types : Connections : Parameters -========================================================================================================== -MatrixMultiply | Inputs : left : [None, ..., None] : bool | float | int : 'input' : Unknown - | 
right : [None, 3] : float : 'output_0' : 0 - | ------------------------------------------------------------------------------------- - | Outputs : output : [None, ..., 3] : float : 'output_1' : 0 ----------------------------------------------------------------------------------------------------------- + MatrixMultiply +--------------------------------------------------------------------------------------------- +Model Name | Model Keys + | -------------------------------------------------------------- + | Keys : Shapes : Types : Connections : Parameters +============================================================================================= +MatrixMultiply | Inputs : left : [None, ..., None] : float : 'input' : Unknown + | right : [None, 3] : float : 'output_0' : 0 + | ------------------------------------------------------------------------ + | Outputs : output : [None, ..., 3] : float : 'output_1' : 0 +--------------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth b/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth index b5a14a3c..4fabd9ca 100644 --- a/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth +++ b/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth @@ -4,8 +4,8 @@ Model Name | Model Keys | -------------------------------------------------------------- | Keys : Shapes : Types : Connections ============================================================================================ -KernelizedSVM | Inputs : input1 : [u1, u2] : bool | float | int : '$input' - | input2 : [u3, u2] : bool | float | int : '$input2' +KernelizedSVM | Inputs : input1 : [u1, u2] : float : '$input' + | input2 : [u3, u2] : float : '$input2' | sigma : [ 1] : bool | float | int : '$sigma' | l_scale : [ 1] : bool | float | int : '$l_scale' | weight : [ 1, u3] : float : '$weight' @@ -34,13 +34,13 @@ Model Name | 
Model Keys | ---------------------------------------------------------- | Keys : Shapes : Types : Connections ===================================================================================== -RBFKernel | Inputs : input1 : [u1, u2] : bool | float | int : 'input1' - | input2 : [u3, u2] : bool | float | int : 'input2' +RBFKernel | Inputs : input1 : [u1, u2] : float : 'input1' + | input2 : [u3, u2] : float : 'input2' | $right : -- : float : -0.5 | sigma : [ 1] : bool | float | int : 'sigma' | l_scale : [ 1] : bool | float | int : 'l_scale' | -------------------------------------------------------------------- - | Outputs : output : [u1, u3] : bool | float | int : Linear.input + | Outputs : output : [u1, u3] : float : Linear.input | 'kernel' ------------------------------------------------------------------------------------- Linear | Inputs : weight : [ 1, u3] : float : 'weight' diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index c99aa422..c6e81546 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -144,8 +144,7 @@ def compile_and_compare( pm = mithril.compile( model, backend=backend, - **compile_kwargs - | {"constant_keys": statics, "trainable_keys": params.keys()}, + **compile_kwargs | {"constant_keys": statics}, ) outputs = pm.evaluate(params=backend_params, data=backend_data) @@ -267,7 +266,7 @@ def test_buffer_1(): def test_buffer_2(): model = Buffer() - model.set_types(input=Tensor) + model.set_differentiability(input=True) params = {"input": [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]} output_gradients = {"output": [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]]} reference_outputs = {"output": [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]} @@ -395,7 +394,7 @@ def test_isnan_2(): def test_nan_to_num_1(): - model = NanToNum() + model = NanToNum(input=Tensor(differentiable=True)) params = {"input": [[1.0, float("nan"), 3.0], [1.0, 2.0, float("nan")]]} output_gradients = {"output": [[10.0, 20.0, 30.0], [40.0, 50.0, 
60.0]]} reference_outputs = {"output": [[1.0, 0.0, 3.0], [1.0, 2.0, 0.0]]} @@ -668,7 +667,9 @@ def test_conv1d_6(): def test_where_1(): - model = Where() + model = Where( + input1=Tensor(differentiable=True), input2=Tensor(differentiable=True) + ) data = { "cond": [True, True, False, False, True], @@ -699,7 +700,9 @@ def test_where_1(): def test_where_2(): - model = Where() + model = Where( + input1=Tensor(differentiable=True), input2=Tensor(differentiable=True) + ) data = { "cond": [True, True, False, False, True], @@ -1242,7 +1245,7 @@ def test_to_tensor(): def test_reduce_prod_1(): - model = Prod(axis=None) + model = Prod(input=Tensor(differentiable=True), axis=None) params = {"input": [1.0, 2.0, 3.0, 4.0, 5.0]} output_gradients = {"output": 1.0} reference_outputs = {"output": 120.0} @@ -1262,7 +1265,7 @@ def test_reduce_prod_1(): def test_reduce_prod_2(): - model = Prod(axis=None) + model = Prod(input=Tensor(differentiable=True), axis=None) params = {"input": [1.0, 0.0, 3.0, 4.0, 5.0]} output_gradients = {"output": 1.0} reference_outputs = {"output": 0.0} @@ -1282,7 +1285,7 @@ def test_reduce_prod_2(): def test_reduce_prod_3(): - model = Prod(axis=None) + model = Prod(input=Tensor(differentiable=True), axis=None) params = {"input": [1.0, 0.0, 3.0, 0.0, 5.0]} output_gradients = {"output": 12.0} reference_outputs = {"output": 0.0} @@ -1302,7 +1305,7 @@ def test_reduce_prod_3(): def test_reduce_prod_4(): - model = Prod(axis=1) + model = Prod(input=Tensor(differentiable=True), axis=1) params = {"input": [[1.0, 2.0], [3.0, 4.0]]} output_gradients = {"output": [2.0, 3.0]} reference_outputs = {"output": [2.0, 12.0]} @@ -1322,7 +1325,7 @@ def test_reduce_prod_4(): def test_reduce_prod_5(): - model = Prod(axis=(1, 2)) + model = Prod(input=Tensor(differentiable=True), axis=(1, 2)) params = { "input": [ [[1.0, 2.0], [3.0, 4.0]], @@ -1358,7 +1361,7 @@ def test_reduce_prod_5(): def test_reduce_prod_6(): - model = Prod(axis=(1, 2), keepdim=True) + model = 
Prod(input=Tensor(differentiable=True), axis=(1, 2), keepdim=True) params = { "input": [ [[1.0, 2.0], [3.0, 4.0]], @@ -1642,7 +1645,7 @@ def test_eye_complement_w_dtype(): def test_squeeze_1(): - model = Squeeze() + model = Squeeze(input=Tensor(differentiable=True)) params = {"input": list_full(1.0, 3, 1, 4, 2, 1)} output_gradients = {"output": list_full(1.0, 3, 4, 2)} reference_outputs = {"output": list_full(1.0, 3, 4, 2)} @@ -1662,7 +1665,7 @@ def test_squeeze_1(): def test_squeeze_2(): - model = Squeeze() + model = Squeeze(input=Tensor(differentiable=True)) params = {"input": list_full(1.0, 3, 1, 4, 2, 1, 1, 1, 5)} output_gradients = {"output": list_full(1.0, 3, 4, 2, 5)} reference_outputs = {"output": list_full(1.0, 3, 4, 2, 5)} @@ -1682,7 +1685,7 @@ def test_squeeze_2(): def test_broadcast_to_1(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": list_full(1.0, 1, 1)} output_gradients = {"output": list_full(1.0, 3, 3)} reference_outputs = {"output": list_full(1.0, 3, 3)} @@ -1703,7 +1706,7 @@ def test_broadcast_to_1(): def test_broadcast_to_2(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": [4.0]} output_gradients = {"output": [[3.0, 4.0], [5.0, 6.0]]} reference_outputs = {"output": [[4.0, 4.0], [4.0, 4.0]]} @@ -1724,7 +1727,7 @@ def test_broadcast_to_2(): def test_broadcast_to_3(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": [[1.0], [7.0]]} output_gradients = {"output": [[3.0, 4.0], [5.0, 6.0]]} reference_outputs = {"output": [[1.0, 1.0], [7.0, 7.0]]} @@ -1745,7 +1748,7 @@ def test_broadcast_to_3(): def test_broadcast_to_4(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": [1.0]} output_gradients = {"output": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]} reference_outputs = {"output": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]} @@ -1766,7 +1769,7 @@ def test_broadcast_to_4(): 
def test_broadcast_to_5(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": [[1.0, 2.0], [3.0, 4.0]]} output_gradients = { "output": [ @@ -1803,7 +1806,7 @@ def test_broadcast_to_5(): def test_norm_modifier_1(): - model = NormModifier() + model = NormModifier(input=Tensor(differentiable=True)) params = {"input": 3.0} output_gradients = {"output": 2.0} reference_outputs = {"output": 3.0} @@ -1823,7 +1826,7 @@ def test_norm_modifier_1(): def test_norm_modifier_2(): - model = NormModifier() + model = NormModifier(input=Tensor(differentiable=True)) params = {"input": 6.0} output_gradients = {"output": 2.0} reference_outputs = {"output": 4.0} @@ -1843,7 +1846,7 @@ def test_norm_modifier_2(): def test_norm_modifier_3(): - model = NormModifier() + model = NormModifier(input=Tensor(differentiable=True)) params = {"input": -1.0} output_gradients = {"output": 2.0} reference_outputs = {"output": 3.0} @@ -1916,6 +1919,7 @@ def test_size_3(): def test_scaled_dot_product_1(): model = ScaledDotProduct() + model.set_differentiability(query=True, key=True, value=True) params = {"query": [[1.0]], "key": [[1.0]], "value": [[1.0]]} output_gradients = {"output": [[1.0]]} reference_outputs = {"output": [[1.0]]} @@ -1936,6 +1940,7 @@ def test_scaled_dot_product_1(): def test_scaled_dot_product_2(): model = ScaledDotProduct() + model.set_differentiability(query=True, key=True, value=True) params = { "query": [[1.0, 1.0], [1.0, 1.0]], "key": [[1.0, 1.0], [1.0, 1.0]], @@ -2084,7 +2089,7 @@ def test_slice_4(): def test_log_1(): - model = Log() + model = Log(input=Tensor(differentiable=True)) params = {"input": [3.0]} output_gradients = {"output": [1.0]} @@ -2107,7 +2112,7 @@ def test_log_1(): def test_log_2(): - model = Log() + model = Log(input=Tensor(differentiable=True)) params = {"input": [3.0, 2.0]} output_gradients = {"output": [1.0, 1.0]} diff --git a/tests/scripts/test_differentiablity.py b/tests/scripts/test_differentiablity.py 
index 431296b5..ada402c5 100644 --- a/tests/scripts/test_differentiablity.py +++ b/tests/scripts/test_differentiablity.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import mithril from mithril import JaxBackend from mithril.framework.common import Tensor @@ -96,7 +98,6 @@ def test_set_diff_data_and_param(): def test_match_tensor_with_value_data_and_param(): model1 = Multiply() model1.set_types(left=Tensor) - model1.set_differentiability(left=False) assert not model1.left.metadata.differentiable @@ -112,23 +113,44 @@ def test_match_tensor_with_value_data_and_param(): assert model.my_input.metadata.differentiable # type: ignore -def test_match_tensor_with_value_data_and_param_rev(): +def test_match_tensor_with_value_data_and_param_error(): + model1 = Multiply() + model1.set_types(left=Tensor) + model1.set_differentiability(left=False) + + assert not model1.left.metadata.differentiable + model2 = Multiply() model2.set_types(left=Tensor) model2.set_differentiability(left=True) assert model2.left.metadata.differentiable + model = Model() + model += model1(left="my_input") + with pytest.raises(ValueError) as err_info: + model += model2(left="my_input") + assert str(err_info.value) == "Differentiability mismatch!" 
+ + +def test_match_tensor_with_value_data_and_param_error_rev(): model1 = Multiply() model1.set_types(left=Tensor) - model1.set_differentiability(left=False) + model1.set_differentiability(left=True) - assert not model1.left.metadata.differentiable + assert model1.left.metadata.differentiable + + model2 = Multiply() + model2.set_types(left=Tensor) + model2.set_differentiability(left=False) + + assert not model2.left.metadata.differentiable model = Model() model += model1(left="my_input") - model += model2(left="my_input") - assert model.my_input.metadata.differentiable # type: ignore + with pytest.raises(ValueError) as err_info: + model += model2(left="my_input") + assert str(err_info.value) == "Differentiability mismatch!" def test_diff_inference(): diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index f042846b..890924d7 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -958,12 +958,12 @@ def test_valued_scalar_in_init(): outer_model = Model() outer_model |= model() - model_dict_created = dict_conversions.model_to_dict(model) + model_dict_created = dict_conversions.model_to_dict(outer_model) model_recreated = dict_conversions.dict_to_model(model_dict_created) model_dict_recreated = dict_conversions.model_to_dict(model_recreated) assert model_dict_created == model_dict_recreated - assert_models_equal(model, model_recreated) + assert_models_equal(outer_model, model_recreated) def test_valued_scalar_in_extend(): @@ -973,12 +973,12 @@ def test_valued_scalar_in_extend(): outer_model = Model() outer_model |= model() - model_dict_created = dict_conversions.model_to_dict(model) + model_dict_created = dict_conversions.model_to_dict(outer_model) model_recreated = dict_conversions.dict_to_model(model_dict_created) model_dict_recreated = dict_conversions.model_to_dict(model_recreated) assert model_dict_created == model_dict_recreated - assert_models_equal(model, 
model_recreated) + assert_models_equal(outer_model, model_recreated) def test_valued_scalar_iokey(): @@ -990,12 +990,12 @@ def test_valued_scalar_iokey(): outer_model = Model() outer_model |= model(axis=IOKey(name="axis", value=1)) - model_dict_created = dict_conversions.model_to_dict(model) + model_dict_created = dict_conversions.model_to_dict(outer_model) model_recreated = dict_conversions.dict_to_model(model_dict_created) model_dict_recreated = dict_conversions.model_to_dict(model_recreated) assert model_dict_created == model_dict_recreated - assert_models_equal(model, model_recreated) + assert_models_equal(outer_model, model_recreated) def test_non_valued_scalar(): @@ -1005,9 +1005,84 @@ def test_non_valued_scalar(): outer_model = Model() outer_model |= model() + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + +def test_assigned_shapes(): + model = Model() + model |= Buffer()(input="buff_input", output=IOKey(name="buff_out")) + model |= Mean(axis=TBD)(input="mean_input", output=IOKey(name="mean_out")) + model.set_shapes(buff_input=[1, 2, ("V", ...)], mean_input=[("V", ...), 3, 4]) + + model_dict_created = dict_conversions.model_to_dict(model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(model, model_recreated) + + assert ( + model_dict_created.get("assigned_shapes") + == model_dict_recreated.get("assigned_shapes") + == [ + [("buff_input", [1, 2, "V,..."]), ("mean_input", ["V,...", 3, 4])], + ] + ) + + +def test_assigned_types_1(): + model = Model() + model |= Buffer()(input="buff_input", output=IOKey(name="buff_out")) + 
model |= Mean(axis=TBD)(input="mean_input", output=IOKey(name="mean_out")) + model.set_types(mean_input=Tensor[int | float]) + model_dict_created = dict_conversions.model_to_dict(model) model_recreated = dict_conversions.dict_to_model(model_dict_created) model_dict_recreated = dict_conversions.model_to_dict(model_recreated) assert model_dict_created == model_dict_recreated assert_models_equal(model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [ + ("mean_input", "tensor"), + ] + ) + + +def test_assigned_types_2(): + model = Model() + model |= Buffer()(input="buff_input", output=IOKey(name="buff_out")) + model |= Mean(axis=TBD)(input="mean_input", output=IOKey(name="mean_out")) + model.set_types(mean_input=Tensor[int | float]) + + outer_model = Model() + outer_model |= model + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [] + ) + + assert ( + model_dict_created["submodels"]["m_0"].get("assigned_types") # type: ignore + == model_dict_recreated["submodels"]["m_0"].get("assigned_types") # type: ignore + == [ + ("mean_input", "tensor"), + ] + ) From f3a10a65b9cdf5a9721ecc7bc85e362e0ed1becd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Wed, 5 Mar 2025 13:20:30 +0300 Subject: [PATCH 2/7] Assigned info data structure changed from list to dict. 
--- mithril/framework/logical/base.py | 24 ++++++++++-------------- mithril/utils/dict_conversions.py | 15 ++++++--------- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 951274f2..81c40d8b 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -374,16 +374,12 @@ def __init__( self._formula_key: str | None = formula_key # TODO: maybe set it only to Operator / Model. self.parent: BaseModel | None = None - self.assigned_shapes: list[ - list[tuple[tuple[BaseModel, int], ShapeTemplateType]] - ] = [] - self.assigned_types: list[ - tuple[ - tuple[BaseModel, int], - type | UnionType | ScalarType | type[Tensor[int | float | bool]], - ] - ] = [] - self.assigned_differentiabilities: list[tuple[tuple[BaseModel, int], bool]] = [] + self.assigned_shapes: list[dict[tuple[BaseModel, int], ShapeTemplateType]] = [] + self.assigned_types: dict[ + tuple[BaseModel, int], + type | UnionType | ScalarType | type[Tensor[int | float | bool]], + ] = {} + self.assigned_differentiabilities: dict[tuple[BaseModel, int], bool] = {} self.assigned_constraints: list[AssignedConstraintType] = [] self.conns = Connections() self.frozen_attributes: list[str] = [] @@ -1442,7 +1438,7 @@ def _set_differentiability( if trace: model_info = self.extract_model_key_index(conn_data) - self.assigned_differentiabilities.append((model_info, value)) + self.assigned_differentiabilities[model_info] = value model = self._get_outermost_parent() model.constraint_solver(updates) @@ -1460,7 +1456,7 @@ def _set_shapes( **kwargs: ShapeTemplateType, ) -> None: # Initialize assigned shapes dictionary to store assigned shapes. 
- assigned_shapes: list[tuple[tuple[BaseModel, int], ShapeTemplateType]] = [] + assigned_shapes: dict[tuple[BaseModel, int], ShapeTemplateType] = {} updates = Updates() if shapes is None: shapes = {} @@ -1482,7 +1478,7 @@ def _set_shapes( # In order to store assigned shapes, we need to store corresponding model # and index of the connection for that model. model_info = self.extract_model_key_index(conn) - assigned_shapes.append((model_info, shape)) + assigned_shapes[model_info] = shape # Apply updates to the shape nodes. for key in chain(shapes, kwargs): @@ -1533,7 +1529,7 @@ def _set_types( if trace: # Store assigned types in the model. model_info = self.extract_model_key_index(conn) - self.assigned_types.append((model_info, key_type)) + self.assigned_types[model_info] = key_type # Run the constraints for updating affected connections. model.constraint_solver(updates) diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 9630b57b..21b82c50 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -155,11 +155,10 @@ def _serialize_assigned_info( differentiability_info: list[tuple[tuple[str, int] | str, bool]] = [] constraints_info: list[AssignedConstraintType] = [] - # Shapes conversion. + # Shapes info. for shp_info in model.assigned_shapes: info_list: list[tuple[tuple[str, int] | str, ShapeTemplateType]] = [] - for sub_info in shp_info: - (m, key_index), shape_info = sub_info + for (m, key_index), shape_info in shp_info.items(): # Key info. key_info = _extract_key_info(model, submodel_dict, (m, key_index)) # Shape info. @@ -173,9 +172,8 @@ def _serialize_assigned_info( info_list.append((key_info, shape_list)) shapes_info.append(info_list) - # Types conversion. - for type_info in model.assigned_types: - (m, key_index), typ = type_info + # Types info. + for (m, key_index), typ in model.assigned_types.items(): # Key info. 
key_info = _extract_key_info(model, submodel_dict, (m, key_index)) # Combine key info with type info. @@ -184,9 +182,8 @@ def _serialize_assigned_info( elif typ is not ToBeDetermined: types_info.append((key_info, str(typ))) - # Differentiability settings for models and keys. - for diff_info in model.assigned_differentiabilities: - (m, key_index), status = diff_info + # Differentiability info. + for (m, key_index), status in model.assigned_differentiabilities.items(): # Key info. key_info = _extract_key_info(model, submodel_dict, (m, key_index)) # Combine key info with differentiability status. From 638a7b2993a3a0c8e789e204ed2b3e1c162757f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Wed, 5 Mar 2025 14:43:25 +0300 Subject: [PATCH 3/7] Tests added for multiple type setting for the same connection. --- mithril/utils/dict_conversions.py | 2 +- tests/scripts/test_model_to_dict_rtt.py | 77 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 21b82c50..b1a1583b 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -180,7 +180,7 @@ def _serialize_assigned_info( if get_origin(typ) is Tensor: types_info.append((key_info, "tensor")) elif typ is not ToBeDetermined: - types_info.append((key_info, str(typ))) + types_info.append((key_info, str(typ.__name__))) # type: ignore # Differentiability info. 
for (m, key_index), status in model.assigned_differentiabilities.items(): diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index 890924d7..e7063160 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -1086,3 +1086,80 @@ def test_assigned_types_2(): ("mean_input", "tensor"), ] ) + + +def test_assigned_types_multiple_times(): + model = Model() + model |= Buffer()(input="buff_input", output=IOKey(name="buff_out")) + mean_model = Mean(axis=TBD) + model |= mean_model(input="mean_input", output=IOKey(name="mean_out")) + model.set_types(mean_input=Tensor[int | float]) + model.set_types({mean_model.input: Tensor[int | float]}) + + outer_model = Model() + outer_model |= model + + # Assert only one assignment made even thought set multiple + # times. + assert len(model.assigned_types) == 1 + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [] + ) + + assert ( + model_dict_created["submodels"]["m_0"].get("assigned_types") # type: ignore + == model_dict_recreated["submodels"]["m_0"].get("assigned_types") # type: ignore + == [ + ("mean_input", "tensor"), + ] + ) + + +def test_assigned_types_multiple_times_different_types(): + model = Model() + buff_model = Buffer() + model |= buff_model(input="buff_input", output=IOKey(name="buff_out")) + mean_model = Mean(axis=TBD) + model |= mean_model(input="mean_input", output=IOKey(name="mean_out")) + # Set types for buff_model 2 times with different types. + # Note that the last assignment will be used. 
+ model.set_types(buff_input=Tensor[int | float] | int | float) + model.set_types({buff_model.input: int}) + + outer_model = Model() + outer_model |= model + + # Assert only one assignment made even thought set multiple + # times. + assert len(model.assigned_types) == 1 + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [] + ) + + assert ( + model_dict_created["submodels"]["m_0"].get("assigned_types") # type: ignore + == model_dict_recreated["submodels"]["m_0"].get("assigned_types") # type: ignore + == [ + ("buff_input", "int"), + ] + ) From 7485b4208d46b3939a050345a68f2a295254e9ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Wed, 5 Mar 2025 15:16:50 +0300 Subject: [PATCH 4/7] json file updated for newly added tests on main branch. 
--- examples/flux/t5.py | 4 ++-- tests/json_files/models_directed_test.json | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/flux/t5.py b/examples/flux/t5.py index e4c78575..e8fbf58b 100644 --- a/examples/flux/t5.py +++ b/examples/flux/t5.py @@ -180,9 +180,9 @@ def load_t5_encoder( repo_id: str = "black-forest-labs/FLUX.1-schnell", max_len: int = 256, ) -> ml.models.PhysicalModel: - config = hf_hub_download(repo_id, "text_encoder_2/config.json") + config_path = hf_hub_download(repo_id, "text_encoder_2/config.json") - with open(config) as f: + with open(config_path) as f: config = json.load(f) t5 = t5_encode(config, name="encoder") diff --git a/tests/json_files/models_directed_test.json b/tests/json_files/models_directed_test.json index 14f9ee2d..539698f0 100644 --- a/tests/json_files/models_directed_test.json +++ b/tests/json_files/models_directed_test.json @@ -2256,7 +2256,7 @@ "test_maxpool1d_1_2d_input": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 } @@ -2287,7 +2287,7 @@ "test_maxpool1d_2_2d_input": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 3 } @@ -2316,7 +2316,7 @@ "test_maxpool1d_3_2d_input": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "stride": 1 @@ -2348,7 +2348,7 @@ "test_maxpool1d_4_2d_input": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "padding": [0,1] @@ -2681,7 +2681,7 @@ "test_maxpool2d_1_3d_input": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 
2 }, @@ -2724,7 +2724,7 @@ "test_maxpool2d_2_3d_input": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 3 } @@ -2761,7 +2761,7 @@ "test_maxpool2d_3_3d_input": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "stride": 1 @@ -2807,7 +2807,7 @@ "test_maxpool2d_4_3d_input": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "padding": [[1,0],[0,1]] From d799007482282272b7356bdc5c5e236b4c952828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Thu, 6 Mar 2025 15:09:47 +0300 Subject: [PATCH 5/7] Assigned attributes hold Connection data instead of (model, index). Primitives are frozen at the end of instantiation. Bugs fixed. --- mithril/framework/logical/base.py | 107 +++++++++++-- mithril/framework/logical/operator.py | 4 +- mithril/framework/logical/operators.py | 20 +-- mithril/framework/logical/primitive.py | 1 + mithril/models/models.py | 80 +++++----- mithril/models/primitives.py | 6 +- mithril/utils/dict_conversions.py | 193 ++++++++++++++---------- tests/scripts/helper.py | 51 +++++-- tests/scripts/test_model_to_dict_rtt.py | 30 +++- tests/scripts/test_scripts.py | 4 +- 10 files changed, 333 insertions(+), 163 deletions(-) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 81c40d8b..f819ad71 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -374,13 +374,17 @@ def __init__( self._formula_key: str | None = formula_key # TODO: maybe set it only to Operator / Model. 
self.parent: BaseModel | None = None - self.assigned_shapes: list[dict[tuple[BaseModel, int], ShapeTemplateType]] = [] + self.assigned_shapes: list[dict[ConnectionData, ShapeTemplateType]] = [] self.assigned_types: dict[ - tuple[BaseModel, int], + ConnectionData, type | UnionType | ScalarType | type[Tensor[int | float | bool]], ] = {} - self.assigned_differentiabilities: dict[tuple[BaseModel, int], bool] = {} + self.assigned_differentiabilities: dict[ConnectionData, bool] = {} self.assigned_constraints: list[AssignedConstraintType] = [] + self.assigned_canonicals: dict[str, set[ConnectionData]] = { + "cins": set(), + "couts": set(), + } self.conns = Connections() self.frozen_attributes: list[str] = [] self.dependency_map = DependencyMap(self.conns) @@ -541,7 +545,10 @@ def _add_connection( given_connection = self._create_connection(edge, None) assert isinstance(given_connection, ConnectionData) - if given_connection.metadata.differentiable != edge.differentiable: + if ( + given_connection.metadata.differentiable is not None + and given_connection.metadata.differentiable != edge.differentiable + ): set_diff = given_connection.metadata.differentiable # Connection is given as a Connection object. @@ -791,6 +798,9 @@ def _merge_connections( # TODO: Deleted connection's 'key' attribute is not updated. # Consider updating it. + # Update assigned attributes with conn2 to conn1. + self._update_assigned_attributes(conn1, conn2) + return updates def extend( @@ -1379,6 +1389,23 @@ def _create_connection( def extract_model_key_index( self, connection: ConnectionData ) -> tuple[BaseModel, int]: + """ + Extracts the model and key index for a given connection. + + This method determines the model and the index of the key associated with + that model. It handles both frozen and non-frozen states of the model. + + Args: + connection (ConnectionData): The connection data for which the model and + key index need to be extracted. 
+ + Returns: + tuple[BaseModel, int]: A tuple containing the model and the index + of the key for that model. + + Raises: + KeyError: If the connection is not found in the model. + """ if connection not in self.conns.all.values(): raise KeyError(f"Connection {connection.key} is not found in the model") @@ -1390,10 +1417,8 @@ def extract_model_key_index( else: assert connection.model is not None key = connection.key - if ( - _model := connection.model - ) is self and connection.key not in self.external_keys: - # Always store info using freezed models. + if (_model := connection.model) is self: + # Always return info using freezed models. model_info: ( list[tuple[BaseModel, OrderedSet[ConnectionData]]] | tuple[BaseModel, OrderedSet[ConnectionData]] @@ -1437,8 +1462,7 @@ def _set_differentiability( updates |= conn_data.set_differentiability(value) if trace: - model_info = self.extract_model_key_index(conn_data) - self.assigned_differentiabilities[model_info] = value + self.assigned_differentiabilities[conn_data] = value model = self._get_outermost_parent() model.constraint_solver(updates) @@ -1456,7 +1480,7 @@ def _set_shapes( **kwargs: ShapeTemplateType, ) -> None: # Initialize assigned shapes dictionary to store assigned shapes. - assigned_shapes: dict[tuple[BaseModel, int], ShapeTemplateType] = {} + assigned_shapes: dict[ConnectionData, ShapeTemplateType] = {} updates = Updates() if shapes is None: shapes = {} @@ -1477,8 +1501,7 @@ def _set_shapes( shape_nodes[key] = (given_repr.node, inner_key) # In order to store assigned shapes, we need to store corresponding model # and index of the connection for that model. - model_info = self.extract_model_key_index(conn) - assigned_shapes[model_info] = shape + assigned_shapes[conn] = shape # Apply updates to the shape nodes. for key in chain(shapes, kwargs): @@ -1528,8 +1551,9 @@ def _set_types( updates |= metadata.set_type(key_type) if trace: # Store assigned types in the model. 
- model_info = self.extract_model_key_index(conn) - self.assigned_types[model_info] = key_type + if key_type is Tensor: + key_type = Tensor[int | float | bool] + self.assigned_types[conn] = key_type # Run the constraints for updating affected connections. model.constraint_solver(updates) @@ -1685,6 +1709,11 @@ def cout(self) -> ConnectionData: return next(iter(self.conns.couts)) def set_cin(self, *connections: str | ConnectionData, safe: bool = True) -> None: + self._set_cin(*connections, safe=safe, trace=True) + + def _set_cin( + self, *connections: str | ConnectionData, safe: bool = True, trace: bool = False + ) -> None: self.conns.cins = set() for given_conn in connections: conn = self.conns.get_extracted_connection(given_conn) @@ -1703,8 +1732,15 @@ def set_cin(self, *connections: str | ConnectionData, safe: bool = True) -> None ) else: self.conns.cins.add(conn) + if trace: + self.assigned_canonicals["cins"].add(conn) def set_cout(self, *connections: str | ConnectionData, safe: bool = True) -> None: + self._set_cout(*connections, safe=safe, trace=True) + + def _set_cout( + self, *connections: str | ConnectionData, safe: bool = True, trace: bool = False + ) -> None: self.conns.couts = set() for given_conn in connections: conn = self.conns.get_extracted_connection(given_conn) @@ -1717,6 +1753,8 @@ def set_cout(self, *connections: str | ConnectionData, safe: bool = True) -> Non ) else: self.conns.couts.add(conn) + if trace: + self.assigned_canonicals["couts"].add(conn) def _match_hyper_edges(self, left: IOHyperEdge, right: IOHyperEdge) -> Updates: l_type = left.edge_type @@ -1753,6 +1791,45 @@ def _match_hyper_edges(self, left: IOHyperEdge, right: IOHyperEdge) -> Updates: # Match data of each IOHyperEdge's. return left.match(right) + def _update_assigned_attributes( + self, new: ConnectionData, old: ConnectionData + ) -> None: + """ + Update assigned attributes by replacing occurrences of the old ConnectionData + with the new ConnectionData. 
+ + This method updates the following attributes: + - assigned_shapes: Replaces old ConnectionData with new ConnectionData + in the assigned shapes. + - assigned_types: Replaces old ConnectionData with new ConnectionData + in the assigned types. + - assigned_differentiabilities: Replaces old ConnectionData with new + ConnectionData in the assigned differentiabilities. + - assigned_canonicals: Updates the 'cins' and 'couts' sets by removing the + old ConnectionData and adding the new ConnectionData. + + Args: + new (ConnectionData): The new ConnectionData to replace the old one. + old (ConnectionData): The old ConnectionData to be replaced. + """ + for shape_info in self.assigned_shapes: + if old in shape_info: + shape_info[new] = shape_info.pop(old) + if old in self.assigned_types: + self.assigned_types[new] = self.assigned_types.pop(old) + + if old in self.assigned_differentiabilities: + self.assigned_differentiabilities[new] = ( + self.assigned_differentiabilities.pop(old) + ) + + if old in self.assigned_canonicals["cins"]: + self.assigned_canonicals["cins"].remove(old) + self.assigned_canonicals["cins"].add(new) + elif old in self.assigned_canonicals["couts"]: + self.assigned_canonicals["couts"].remove(old) + self.assigned_canonicals["couts"].add(new) + class DependencyMap: """ diff --git a/mithril/framework/logical/operator.py b/mithril/framework/logical/operator.py index 02d70a21..4a964696 100644 --- a/mithril/framework/logical/operator.py +++ b/mithril/framework/logical/operator.py @@ -106,7 +106,7 @@ def __init__( ) canonical_input_conn = self.conns.get_connection(canonical_input_key) if canonical_input_conn is not None: - self.set_cin(canonical_input_conn, safe=False) + self._set_cin(canonical_input_conn, safe=False) canonical_output_key = ( "output" @@ -115,7 +115,7 @@ def __init__( ) canonical_output_conn = self.conns.get_connection(canonical_output_key) if canonical_output_conn is not None: - self.set_cout(canonical_output_conn, safe=False) + 
self._set_cout(canonical_output_conn, safe=False) self._freeze() @property diff --git a/mithril/framework/logical/operators.py b/mithril/framework/logical/operators.py index e2480459..c3f885b5 100644 --- a/mithril/framework/logical/operators.py +++ b/mithril/framework/logical/operators.py @@ -163,7 +163,7 @@ def __init__( fn=to_tuple_constraints, keys=[Operator.output_key] + [key for key in self.input_keys], ) - self.set_cin() + self._set_cin() class PowerOp(Operator): @@ -301,7 +301,7 @@ def __init__( types=[UpdateType.TYPE], dependencies={edge_constraint}, ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class SubtractOp(Operator): @@ -410,7 +410,7 @@ def __init__( types=[UpdateType.TYPE], dependencies={edge_constraint}, ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class MinimumOp(Operator): @@ -451,7 +451,7 @@ def __init__( keys=[Operator.output_key, "left", "right"], types=[UpdateType.TYPE], ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class MaximumOp(Operator): @@ -492,7 +492,7 @@ def __init__( keys=[Operator.output_key, "left", "right"], types=[UpdateType.TYPE], ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class DivideOp(Operator): @@ -851,7 +851,7 @@ def __init__( fn=to_list_constraints, keys=[Operator.output_key] + [key for key in self.input_keys], ) - self.set_cin() + self._set_cin() class TensorToListOp(Operator): @@ -1298,7 +1298,7 @@ def __init__( super().__init__( formula_key="equal", name=name, operator=operator.eq, left=left, right=right ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class NotEqualOp(RelationalOperatorsOp): @@ -1318,7 +1318,7 @@ def __init__( left=left, right=right, ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class LessEqualOp(RelationalOperatorsOp): @@ -1428,7 
+1428,7 @@ def __init__( keys=[Operator.output_key, "left", "right"], dependencies={edge_constraint}, ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class LogicalAndOp(BitwiseOperatorsOp): @@ -1684,7 +1684,7 @@ def __init__( self._add_constraint( fn=slice_constraints, keys=["output", "start", "stop", "step"] ) - self.set_cin() + self._set_cin() class IndexerOp(Operator): diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 20ab841c..6eb7e51b 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -32,6 +32,7 @@ def __init__( super().__init__(name=name, enforce_jit=model._jittable) self._extend(model, {k: k for k in model.external_keys}, trace=False) self.expose_keys(*model.external_keys) + self._freeze() @property def submodel(self) -> Operator: diff --git a/mithril/models/models.py b/mithril/models/models.py index fd1959eb..6b99fe9f 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -200,7 +200,7 @@ def __init__( output=IOKey(name="output"), ) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -301,7 +301,7 @@ def __init__( dilation=dt_converter.output, output=IOKey(name="output"), ) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -397,7 +397,7 @@ def __init__( conv_connections["bias"] = IOKey("bias", differentiable=True) self |= PrimitiveConvolution1D(use_bias=use_bias)(**conv_connections) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -491,7 +491,7 @@ def __init__( conv_connections["bias"] = IOKey("bias", differentiable=True) self |= PrimitiveConvolution2D(use_bias=use_bias)(**conv_connections) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) 
self._freeze() def __call__( # type: ignore[override] @@ -558,7 +558,7 @@ def __init__( self |= mult(left=input_key, right=weight_key, output=output) self._set_shapes(**shapes) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -605,7 +605,7 @@ def __init__( right=IOKey(name="bias", value=bias), output=IOKey(name="output"), ) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -648,7 +648,7 @@ def __init__( bias=IOKey("bias", value=bias), ) self += activation(output=IOKey(name="output")) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -718,7 +718,7 @@ def __init__( add._set_shapes(**shapes) # TODO: Remove below Buffer after required naming-related changes are done. self |= Buffer()(input=self.cout, output=IOKey(name="output")) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -799,7 +799,7 @@ def __init__( add._set_shapes(**shapes) self |= Buffer()(input=self.cout, output=IOKey(name="output")) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -842,8 +842,8 @@ def __init__( self |= abs_model(input=IOKey("input", value=input)) self += Sum()(output=IOKey(name="output")) - self.set_cin("input", safe=False) - self.set_cout("output", safe=False) + self._set_cin("input", safe=False) + self._set_cout("output", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -870,8 +870,8 @@ def __init__( self += Square()(input=IOKey("input", value=input)) self += Sum() self += Multiply()(right=Tensor(0.5), output=IOKey(name="output")) - self.set_cin("input", safe=False) - self.set_cout("output", safe=False) + self._set_cin("input", safe=False) + self._set_cout("output", 
safe=False) self._freeze() def __call__( # type: ignore[override] @@ -910,7 +910,7 @@ def __init__( ) shapes: dict[str, ShapeTemplateType] = {"input": [1, "N"], "kernel": ["N", "N"]} self._set_shapes(**shapes) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -983,7 +983,7 @@ def __init__( } self._set_shapes(**shapes) - self.set_cin("input1", "input2", safe=False) + self._set_cin("input1", "input2", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -1044,7 +1044,7 @@ def __init__( fn=polynomial_kernel_constraint, keys=["poly_coef", "degree"], ) - self.set_cin("input1", "input2", safe=False) + self._set_cin("input1", "input2", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -1118,7 +1118,7 @@ def __init__( "kernel": ["N", "M"], } self._set_shapes(**shapes) - self.set_cin("input1", "input2", safe=False) + self._set_cin("input1", "input2", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -1166,7 +1166,7 @@ def __init__( ) self += decision_model(output=IOKey(name="decision_output")) - self.set_cout(linear_model.output) + self._set_cout(linear_model.output) self._freeze() def __call__( # type: ignore[override] @@ -1216,7 +1216,7 @@ def __init__( input=linear_model.output, output=IOKey(name="probs_output") ) - self.set_cout(linear_model.output) + self._set_cout(linear_model.output) self._freeze() def __call__( # type: ignore[override] @@ -1298,7 +1298,7 @@ def __init__( # Add current layer to the model. 
self += current_layer(**kwargs) prev_layer = current_layer - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( @@ -1417,8 +1417,8 @@ def __init__( } self._set_shapes(**shapes) - self.set_cin("input", safe=False) - self.set_cout("output") + self._set_cin("input", safe=False) + self._set_cout("output") self._freeze() def __call__( # type: ignore[override] @@ -1592,8 +1592,8 @@ def __init__( } self._set_shapes(**shapes) - self.set_cin("input", safe=False) - self.set_cout("output") + self._set_cin("input", safe=False) + self._set_cout("output") self._freeze() def __call__( # type: ignore[override] @@ -1872,8 +1872,8 @@ def __init__( ) prev_cell = current_cell - self.set_cin("input") - self.set_cout(current_cell.output) + self._set_cin("input") + self._set_cout(current_cell.output) self._freeze() def __call__( @@ -1998,8 +1998,8 @@ def __init__( input=concat_input_args, output=IOKey(name="hidden_concat", value=hidden_concat), ) - self.set_cin("input0") - self.set_cout("hidden_concat") + self._set_cin("input0") + self._set_cout("hidden_concat") self._freeze() def __call__( @@ -2055,7 +2055,7 @@ def __init__( initial_hidden=permutation_model.output, **(dec_input_mapping | dec_output_mapping), ) - self.set_cout(decoder.cout) + self._set_cout(decoder.cout) self._freeze() @@ -2103,7 +2103,7 @@ def __init__( initial_hidden=encoder.hidden_concat, **(dec_input_mapping | dec_output_mapping), ) - self.set_cout(decoder.cout) + self._set_cout(decoder.cout) self._freeze() def __call__(self, **model_keys: ConnectionType) -> ExtendInfo: @@ -2157,7 +2157,7 @@ def __init__( norm=modifier_model.output, output=IOKey(name="output"), ) - self.set_cin("input1", "input2", safe=False) + self._set_cin("input1", "input2", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -2381,8 +2381,8 @@ def __init__( ) self._set_shapes(distances=["N", "N"], pred_distances=["N", "N"]) - self.set_cin("distances", safe=False) - 
self.set_cout("output") + self._set_cin("distances", safe=False) + self._set_cout("output") self._freeze() def __call__( # type: ignore[override] @@ -2719,7 +2719,7 @@ def __init__( output=IOKey(name="confidence", value=confidence), ) - self.set_cout(pred_model.output) + self._set_cout(pred_model.output) shapes: dict[str, ShapeTemplateType] = { "label": ["N", 1], "s": [1], @@ -2900,7 +2900,7 @@ def __init__( self |= Buffer()(input=label_key, output=IOKey("label_formatted")) self |= Buffer()(input=result, output=IOKey("output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] @@ -2962,7 +2962,7 @@ def __init__( denominator=n_prediction.tensor(), output=IOKey(name="output"), ) - self.set_cin(self.pred) + self._set_cin(self.pred) def __call__( # type: ignore[override] self, @@ -3121,7 +3121,7 @@ def __init__( self |= Buffer()(input=precision, output=IOKey(name="output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] @@ -3280,7 +3280,7 @@ def __init__( self |= Buffer()(input=recall, output=IOKey(name="output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] @@ -3444,7 +3444,7 @@ def __init__( self |= Buffer()(input=precision, output=IOKey(name="output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] @@ -3505,7 +3505,7 @@ def __init__( self |= Buffer()(auc_score, IOKey("output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index d7beb19f..10eb2eb7 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -1748,7 +1748,7 @@ def __init__( fn=general_tensor_type_constraint, keys=[Operator.output_key, "left", "right", "norm"], ) - self.set_cin("left", "right", safe=False) + 
self._set_cin("left", "right", safe=False) def __call__( # type: ignore[override] self, @@ -2068,7 +2068,7 @@ def __init__( step=BaseKey(type=int | float, value=step), dtype=BaseKey(type=types.Dtype | None, value=dtype), ) - self.set_cin("stop", safe=False) + self._set_cin("stop", safe=False) if not all_defined: self._add_constraint( @@ -2493,7 +2493,7 @@ def __init__( fn=general_tensor_type_constraint, keys=[Operator.output_key, "input1", "input2"], ) - self.set_cin("input1", safe=False) + self._set_cin("input1", safe=False) def __call__( # type: ignore[override] self, diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index b1a1583b..bd289879 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -84,12 +84,12 @@ class ModelDict(TypedDict, total=False): assigned_types: list[tuple[tuple[str, int] | str, str]] assigned_differentiabilities: list[tuple[tuple[str, int] | str, bool]] assigned_constraints: list[AssignedConstraintType] + assigned_canonicals: dict[str, list[tuple[str, int] | str]] tuples: list[str] enums: dict[str, str] unnamed_keys: list[str] submodels: dict[str, ModelDict] connections: dict[str, dict[str, str | ConnectionDict]] - canonical_keys: dict[str, tuple[set[str], set[str]]] class TrainModelDict(TypedDict): @@ -126,9 +126,9 @@ class TrainModelDict(TypedDict): def _extract_key_info( - model: BaseModel, submodel_dict: dict[BaseModel, str], info: tuple[BaseModel, int] + model: BaseModel, submodel_dict: dict[BaseModel, str], conn: ConnectionData ) -> tuple[str, int] | str: - m, key_index = info + m, key_index = model.extract_model_key_index(conn) m_name = submodel_dict[m] if m is not model else "self" key_name = list(model.external_keys)[key_index] @@ -149,6 +149,7 @@ def _serialize_assigned_info( list[tuple[tuple[str, int] | str, str]], list[tuple[tuple[str, int] | str, bool]], list[AssignedConstraintType], + dict[str, list[tuple[str, int] | str]], ]: shapes_info: 
list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]] = [] types_info: list[tuple[tuple[str, int] | str, str]] = [] @@ -158,9 +159,9 @@ def _serialize_assigned_info( # Shapes info. for shp_info in model.assigned_shapes: info_list: list[tuple[tuple[str, int] | str, ShapeTemplateType]] = [] - for (m, key_index), shape_info in shp_info.items(): + for conn, shape_info in shp_info.items(): # Key info. - key_info = _extract_key_info(model, submodel_dict, (m, key_index)) + key_info = _extract_key_info(model, submodel_dict, conn) # Shape info. shape_list: list[int | str | tuple[str, EllipsisType] | None] = [] for item in shape_info: @@ -173,9 +174,9 @@ def _serialize_assigned_info( shapes_info.append(info_list) # Types info. - for (m, key_index), typ in model.assigned_types.items(): + for conn, typ in model.assigned_types.items(): # Key info. - key_info = _extract_key_info(model, submodel_dict, (m, key_index)) + key_info = _extract_key_info(model, submodel_dict, conn) # Combine key info with type info. if get_origin(typ) is Tensor: types_info.append((key_info, "tensor")) @@ -183,9 +184,9 @@ def _serialize_assigned_info( types_info.append((key_info, str(typ.__name__))) # type: ignore # Differentiability info. - for (m, key_index), status in model.assigned_differentiabilities.items(): + for conn, status in model.assigned_differentiabilities.items(): # Key info. - key_info = _extract_key_info(model, submodel_dict, (m, key_index)) + key_info = _extract_key_info(model, submodel_dict, conn) # Combine key info with differentiability status. differentiability_info.append((key_info, status)) @@ -193,7 +194,49 @@ def _serialize_assigned_info( for constraint in model.assigned_constraints: constraints_info.append(constraint) - return shapes_info, types_info, differentiability_info, constraints_info + # Canonical keys. 
+ assigned_canonicals: dict[str, list[tuple[str, int] | str]] = { + "cins": [], + "couts": [], + } + for conn in model.assigned_canonicals["cins"]: + key_info = _extract_key_info(model, submodel_dict, conn) + assigned_canonicals["cins"].append(key_info) + for conn in model.assigned_canonicals["couts"]: + key_info = _extract_key_info(model, submodel_dict, conn) + assigned_canonicals["couts"].append(key_info) + # Sort cins and couts in order to get exactly the same json for + # the same model. + assigned_canonicals["cins"].sort() + assigned_canonicals["couts"].sort() + + return ( + shapes_info, + types_info, + differentiability_info, + constraints_info, + assigned_canonicals, + ) + + +def _construct_info[T]( + model: BaseModel, + submodel_dict: dict[str, BaseModel], + config: dict[ConnectionData, T], + kwargs: dict[str, T], + info: tuple[tuple[str, int] | str, T], +) -> None: + info_key, data = info + if isinstance(info_key, tuple): + model_name, index = info_key + connection = _extract_connection_from_index( + model, model_name, index, submodel_dict + ) + config[connection] = data + elif isinstance(info_key, str): + kwargs[info_key] = data + else: + raise RuntimeError("Unknown info format!") def _extract_connection_from_index( @@ -210,70 +253,55 @@ def _set_assigned_info( model: BaseModel, submodel_dict: dict[str, BaseModel], shapes_info: list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]], - types_info: list[tuple[tuple[str, int] | str, str]], + types_info: list[tuple[tuple[str, int] | str, type]], diffs_info: list[tuple[tuple[str, int] | str, bool]], constraints_info: list[AssignedConstraintType], + canonicals_info: dict[str, list[tuple[str, int] | str]], ) -> None: # Shapes conversion. 
for shp_info in shapes_info: shapes: dict[ConnectionData, ShapeTemplateType] = {} shape_kwargs: dict[str, ShapeTemplateType] = {} for sub_info in shp_info: - shape = sub_info[1] - if isinstance(sub_info[0], tuple): - model_name, index = sub_info[0] - connection = _extract_connection_from_index( - model, model_name, index, submodel_dict - ) - shapes[connection] = shape - elif isinstance(sub_info[0], str): - name = sub_info[0] - shape_kwargs[name] = shape - else: - raise RuntimeError("Unknown shape info format") + _construct_info(model, submodel_dict, shapes, shape_kwargs, sub_info) model.set_shapes(shapes, **shape_kwargs) # Types conversion. - types_config = {} - types_kwargs = {} + types_config: dict[ConnectionData, type] = {} + types_kwargs: dict[str, type] = {} for type_info in types_info: - typ: type - if type_info[1] == "tensor": - typ = Tensor[int | float | bool] - else: - typ = eval(type_info[1]) - - if isinstance(type_info[0], tuple): - model_name, index = type_info[0] - connection = _extract_connection_from_index( - model, model_name, index, submodel_dict - ) - types_config[connection] = typ - elif isinstance(type_info[0], str): - name = type_info[0] - types_kwargs[name] = typ - else: - raise RuntimeError("Unknown type info format!") - + _construct_info(model, submodel_dict, types_config, types_kwargs, type_info) + # if isinstance(type_info[0], tuple): + # model_name, index = type_info[0] + # connection = _extract_connection_from_index( + # model, model_name, index, submodel_dict + # ) + # types_config[connection] = typ + # elif isinstance(type_info[0], str): + # name = type_info[0] + # types_kwargs[name] = typ + # else: + # raise RuntimeError("Unknown type info format!") model.set_types(types_config, **types_kwargs) # Differentiability settings for models and keys. 
diff_config: dict[ConnectionData, bool] = {} diff_kwargs: dict[str, bool] = {} for diff_info in diffs_info: - status = diff_info[1] - if isinstance(diff_info[0], tuple): - model_name, index = diff_info[0] - connection = _extract_connection_from_index( - model, model_name, index, submodel_dict - ) - diff_config[connection] = status - elif isinstance(diff_info[0], str): - name = diff_info[0] - diff_kwargs[name] = status - else: - raise RuntimeError("Unknown differentiability info format!") - model.set_differentiability(diff_config, **diff_kwargs) + _construct_info(model, submodel_dict, diff_config, diff_kwargs, diff_info) + # status = diff_info[1] + # if isinstance(diff_info[0], tuple): + # model_name, index = diff_info[0] + # connection = _extract_connection_from_index( + # model, model_name, index, submodel_dict + # ) + # diff_config[connection] = status + # elif isinstance(diff_info[0], str): + # name = diff_info[0] + # diff_kwargs[name] = status + # else: + # raise RuntimeError("Unknown differentiability info format!") + model.set_differentiability(diff_config, **diff_kwargs) # Constraints. 
for constr_info in constraints_info: @@ -286,6 +314,30 @@ def _set_assigned_info( constrain_fn = constrain_fn_dict[constrain_fn] model.add_constraint(constrain_fn, keys=constr_info["keys"]) # type: ignore + # Canonical keys + cins: set[ConnectionData | str] = set() + couts: set[ConnectionData | str] = set() + for canonical_type, info in canonicals_info.items(): + related_set = cins if canonical_type == "cins" else couts + for item in info: + if canonical_type not in ["cins", "couts"]: + raise RuntimeError("Unknown canonical key type!") + + if isinstance(item, tuple): + model_name, index = item + connection = _extract_connection_from_index( + model, model_name, index, submodel_dict + ) + related_set.add(connection) + elif isinstance(item, str): + related_set.add(item) + else: + raise RuntimeError("Unknown canonical key info format!") + if cins: + model.set_cin(*cins) + if couts: + model.set_cout(*couts) + def create_iokey_kwargs( info: KeyDict, submodels_dict: dict[str, BaseModel] @@ -370,9 +422,6 @@ def dict_to_model( model = type(model_name, (Model,), attrs)(**args) unnamed_keys: list[str] = params.get("unnamed_keys", []) - canonical_keys: dict[str, tuple[set[str], set[str]]] = params.get( - "canonical_keys", {} - ) assert isinstance(model, Model) submodels_dict: dict[str, BaseModel] = {} @@ -406,17 +455,15 @@ def dict_to_model( con = getattr(m, key) model.merge_connections(con, *conns) - if "model" in canonical_keys: - if canonical_keys["model"][0] is not None: - model.set_cin(*canonical_keys["model"][0]) - if canonical_keys["model"][1] is not None: - model.set_cout(*canonical_keys["model"][1]) - # Set all assigned info. 
- assigned_types = params.get("assigned_types", []) + assigned_types = [ + (info, Tensor if typ == "tensor" else eval(typ)) + for info, typ in params.get("assigned_types", []) + ] assigned_shapes = params.get("assigned_shapes", []) assigned_differentiabilities = params.get("assigned_differentiabilities", []) assigned_constraints = params.get("assigned_constraints", []) + assigned_canonicals = params.get("assigned_canonicals", {}) _set_assigned_info( model, submodels_dict, @@ -424,6 +471,7 @@ def dict_to_model( assigned_types, assigned_differentiabilities, assigned_constraints, + assigned_canonicals, ) assert isinstance(model, Model) @@ -446,7 +494,6 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: if model_name == "Model" and model_name in dir(models): connection_dict: dict[str, dict[str, str | ConnectionDict]] = {} - canonical_keys: dict[str, tuple[set[str], set[str]]] = {} submodels: dict[str, ModelDict] = {} # IOHyperEdge -> [model_id, connection_name] @@ -467,18 +514,9 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: connection_dict[model_id] = connection_to_dict( model, submodel, submodel_connections, model_id ) - canonical_keys[model_id] = ( - get_keys(submodel.conns.cins), - get_keys(submodel.conns.couts), - ) - canonical_keys["model"] = ( - get_keys(model.conns.cins), - get_keys(model.conns.couts), - ) model_dict |= { "connections": connection_dict, - "canonical_keys": canonical_keys, "submodels": submodels, } @@ -487,11 +525,14 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: assigned_types, assigned_differentiabilities, assigned_constraints, + assigned_canonicals, ) = _serialize_assigned_info(model, submodel_obj_dict) + model_dict |= {"assigned_shapes": assigned_shapes} model_dict |= {"assigned_types": assigned_types} model_dict |= {"assigned_differentiabilities": assigned_differentiabilities} model_dict |= {"assigned_constraints": assigned_constraints} + model_dict |= {"assigned_canonicals": 
assigned_canonicals} return model_dict diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 516f3088..3ad5703a 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -210,19 +210,44 @@ def assert_models_equal(model1: BaseModel, model2: BaseModel): model1_keys = model1.generate_keys() model2_keys = model2.generate_keys() - # if model1.cin is not None and model2.cin is not None: - # assert model1_keys.get( - # key := model1.cin.key, key - # ) == model2_keys.get(key := model2.cin.key, key) - # assert model1_keys.get( - # key := model1.cout.key, key - # ) == model2_keys.get(key := model2.cout.key, key) - assert {model1_keys.get(con.key, con.key) for con in model2.conns.cins} == { - model2_keys.get(con.key, con.key) for con in model2.conns.cins - } - assert {model1_keys.get(con.key, con.key) for con in model2.conns.couts} == { - model2_keys.get(con.key, con.key) for con in model2.conns.couts - } + # For exact match, we need to compare the keys together with + # model and key_index of corresponding key for the corresponding model. + # Model info is converted also to index, which is the index of the model + # in the DAG, if extracted model is not the model itself. If the extracted + # model is the model itself, then the index is "self". + + # Canonical input keys tests. + model1_cins = set() + model2_cins = set() + for conn in model1.conns.cins: + model, key_index = model1.extract_model_key_index(conn) + model_index = ( + list(model1.dag.keys()).index(model) if model != model1 else "self" + ) + model1_cins.add((model_index, key_index)) + for conn in model2.conns.cins: + model, key_index = model2.extract_model_key_index(conn) + model_index = ( + list(model2.dag.keys()).index(model) if model != model2 else "self" + ) + model2_cins.add((model_index, key_index)) + assert model1_cins == model2_cins + # Canonical output keys tests. 
+ model1_couts = set() + model2_couts = set() + for conn in model1.conns.couts: + model, key_index = model1.extract_model_key_index(conn) + model_index = ( + list(model1.dag.keys()).index(model) if model != model1 else "self" + ) + model1_couts.add((model_index, key_index)) + for conn in model2.conns.couts: + model, key_index = model2.extract_model_key_index(conn) + model_index = ( + list(model2.dag.keys()).index(model) if model != model2 else "self" + ) + model2_couts.add((model_index, key_index)) + assert model1_couts == model2_couts # NOTE: Below assertions will be uncommented after converting # model's dag from topological order to insertion order. diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index e7063160..6b86076b 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -1030,7 +1030,7 @@ def test_assigned_shapes(): model_dict_created.get("assigned_shapes") == model_dict_recreated.get("assigned_shapes") == [ - [("buff_input", [1, 2, "V,..."]), ("mean_input", ["V,...", 3, 4])], + [(("m_0", 0), [1, 2, "V,..."]), (("m_1", 0), ["V,...", 3, 4])], ] ) @@ -1052,7 +1052,7 @@ def test_assigned_types_1(): model_dict_created.get("assigned_types") == model_dict_recreated.get("assigned_types") == [ - ("mean_input", "tensor"), + (("m_1", 0), "tensor"), ] ) @@ -1163,3 +1163,29 @@ def test_assigned_types_multiple_times_different_types(): ("buff_input", "int"), ] ) + + +def test_assigned_types_from_outermost_model(): + model = Model() + buff_model = Buffer() + model |= buff_model(input="buff_input", output=IOKey(name="buff_out")) + model |= Mean(axis=TBD)(input="mean_input", output=IOKey(name="mean_out")) + + outer_model = Model() + outer_model |= model + outer_model |= Buffer()(input="buff_input_2") + outer_model.set_types(buff_input_2=Tensor[int | float]) + outer_model.merge_connections(buff_model.input, "buff_input_2") + + model_dict_created = 
dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [(("m_0", 0), "tensor")] + ) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 9db4999d..e26b2cc2 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -6973,7 +6973,7 @@ def test_extending_operator(): def test_extending_operator_model(): model1 = Buffer() - with pytest.raises(RuntimeError) as err: + with pytest.raises(AttributeError) as err: model1 += Buffer() - assert str(err.value) == "Primitive models cannot have submodels." + assert str(err.value) == "Model is frozen and can not be extended!" From a674bcff44b8aa2630357d767b20e310b41846e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Fri, 7 Mar 2025 09:40:12 +0300 Subject: [PATCH 6/7] Unnecessary comments removed. 
--- mithril/utils/dict_conversions.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index bd289879..db9471da 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -271,17 +271,6 @@ def _set_assigned_info( types_kwargs: dict[str, type] = {} for type_info in types_info: _construct_info(model, submodel_dict, types_config, types_kwargs, type_info) - # if isinstance(type_info[0], tuple): - # model_name, index = type_info[0] - # connection = _extract_connection_from_index( - # model, model_name, index, submodel_dict - # ) - # types_config[connection] = typ - # elif isinstance(type_info[0], str): - # name = type_info[0] - # types_kwargs[name] = typ - # else: - # raise RuntimeError("Unknown type info format!") model.set_types(types_config, **types_kwargs) # Differentiability settings for models and keys. @@ -289,18 +278,6 @@ def _set_assigned_info( diff_kwargs: dict[str, bool] = {} for diff_info in diffs_info: _construct_info(model, submodel_dict, diff_config, diff_kwargs, diff_info) - # status = diff_info[1] - # if isinstance(diff_info[0], tuple): - # model_name, index = diff_info[0] - # connection = _extract_connection_from_index( - # model, model_name, index, submodel_dict - # ) - # diff_config[connection] = status - # elif isinstance(diff_info[0], str): - # name = diff_info[0] - # diff_kwargs[name] = status - # else: - # raise RuntimeError("Unknown differentiability info format!") model.set_differentiability(diff_config, **diff_kwargs) # Constraints. 
From 9b80cb44762754ff4c00d409673d14cd47bfc980 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Fri, 7 Mar 2025 18:07:42 +0300 Subject: [PATCH 7/7] Resolved reviews --- mithril/framework/logical/base.py | 90 +++----------- mithril/utils/dict_conversions.py | 159 +++++++++++++++++------- tests/scripts/helper.py | 14 ++- tests/scripts/test_model_to_dict_rtt.py | 51 ++++++++ 4 files changed, 192 insertions(+), 122 deletions(-) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index f819ad71..193beca3 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -381,10 +381,8 @@ def __init__( ] = {} self.assigned_differentiabilities: dict[ConnectionData, bool] = {} self.assigned_constraints: list[AssignedConstraintType] = [] - self.assigned_canonicals: dict[str, set[ConnectionData]] = { - "cins": set(), - "couts": set(), - } + self.assigned_cins: set[ConnectionData] = set() + self.assigned_couts: set[ConnectionData] = set() self.conns = Connections() self.frozen_attributes: list[str] = [] self.dependency_map = DependencyMap(self.conns) @@ -871,13 +869,6 @@ def extend( submodel_dag[key].key: template for key, template in shape_info.items() } - # # Set given shapes. - # self._set_shapes(**shape_info) # TODO: Should "trace" be set to True?. - # self.constraint_solver(updates) - - # model.constraint_solver.clear() - # model.conns.connections_dict = {} - # Insert to self dag as a FrozenDict."" # Since we update dag in merge_connections, we could not use FrozenDict. self.dag[model] = model_dag = submodel_dag @@ -1386,57 +1377,6 @@ def _create_connection( connection.metadata = metadata return connection - def extract_model_key_index( - self, connection: ConnectionData - ) -> tuple[BaseModel, int]: - """ - Extracts the model and key index for a given connection. - - This method determines the model and the index of the key associated with - that model. 
It handles both frozen and non-frozen states of the model. - - Args: - connection (ConnectionData): The connection data for which the model and - key index need to be extracted. - - Returns: - tuple[BaseModel, int]: A tuple containing the model and the index - of the key for that model. - - Raises: - KeyError: If the connection is not found in the model. - """ - if connection not in self.conns.all.values(): - raise KeyError(f"Connection {connection.key} is not found in the model") - - if self.is_frozen: - _model = self - name = connection._name - assert name is not None - key_index = list(self.external_keys).index(name) - else: - assert connection.model is not None - key = connection.key - if (_model := connection.model) is self: - # Always return info using freezed models. - model_info: ( - list[tuple[BaseModel, OrderedSet[ConnectionData]]] - | tuple[BaseModel, OrderedSet[ConnectionData]] - ) = self.dependency_map.local_input_dependency_map.get( - connection, - self.dependency_map.local_output_dependency_map.get(connection), # type: ignore - ) - _model = ( - model_info[0][0] if isinstance(model_info, list) else model_info[0] - ) - # Get corresponding connection of the _model. 
- _model_conn = _model.conns.get_con_by_metadata(connection.metadata) - assert _model_conn is not None - key = _model_conn.key - key_index = list(_model.external_keys).index(key) - - return (_model, key_index) - def _set_differentiability( self, config: dict[ConnectionData, bool] | None = None, @@ -1658,6 +1598,7 @@ def _add_constraint( keys: list[str], types: list[UpdateType] | None = None, dependencies: set[Constraint] | None = None, + trace: bool = False, ) -> Constraint: all_conns = self.conns.all hyper_edges = [all_conns[key].metadata for key in keys] @@ -1677,6 +1618,10 @@ def _add_constraint( hyper_edge.add_constraint(constr) self.constraint_solver.solver_loop({constr}) + + if trace: + self.assigned_constraints.append({"fn": fn.__name__, "keys": keys}) + return constr def add_constraint( @@ -1686,8 +1631,7 @@ def add_constraint( type: list[UpdateType] | None = None, dependencies: set[Constraint] | None = None, ) -> Constraint: - self.assigned_constraints.append({"fn": fn.__name__, "keys": keys}) - return self._add_constraint(fn, keys, type, dependencies) + return self._add_constraint(fn, keys, type, dependencies, True) @property def cin(self) -> ConnectionData: @@ -1733,7 +1677,7 @@ def _set_cin( else: self.conns.cins.add(conn) if trace: - self.assigned_canonicals["cins"].add(conn) + self.assigned_cins.add(conn) def set_cout(self, *connections: str | ConnectionData, safe: bool = True) -> None: self._set_cout(*connections, safe=safe, trace=True) @@ -1754,7 +1698,7 @@ def _set_cout( else: self.conns.couts.add(conn) if trace: - self.assigned_canonicals["couts"].add(conn) + self.assigned_couts.add(conn) def _match_hyper_edges(self, left: IOHyperEdge, right: IOHyperEdge) -> Updates: l_type = left.edge_type @@ -1822,13 +1766,13 @@ def _update_assigned_attributes( self.assigned_differentiabilities[new] = ( self.assigned_differentiabilities.pop(old) ) - - if old in self.assigned_canonicals["cins"]: - self.assigned_canonicals["cins"].remove(old) - 
self.assigned_canonicals["cins"].add(new) - elif old in self.assigned_canonicals["couts"]: - self.assigned_canonicals["couts"].remove(old) - self.assigned_canonicals["couts"].add(new) + # Assigned canonicals + if old in self.assigned_cins: + self.assigned_cins.remove(old) + self.assigned_cins.add(new) + elif old in self.assigned_couts: + self.assigned_couts.remove(old) + self.assigned_couts.add(new) class DependencyMap: diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index db9471da..ec565d00 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -46,7 +46,7 @@ ) from ..models.train_model import TrainModel from ..utils import model_conversion_lut -from ..utils.utils import convert_to_tuple +from ..utils.utils import OrderedSet, convert_to_tuple class KeyDict(TypedDict, total=False): @@ -84,7 +84,8 @@ class ModelDict(TypedDict, total=False): assigned_types: list[tuple[tuple[str, int] | str, str]] assigned_differentiabilities: list[tuple[tuple[str, int] | str, bool]] assigned_constraints: list[AssignedConstraintType] - assigned_canonicals: dict[str, list[tuple[str, int] | str]] + assigned_cins: list[tuple[str, int] | str] + assigned_couts: list[tuple[str, int] | str] tuples: list[str] enums: dict[str, str] unnamed_keys: list[str] @@ -125,10 +126,60 @@ class TrainModelDict(TypedDict): enum_dict = {"PaddingType": PaddingType} +def extract_model_key_index( + model: BaseModel, connection: ConnectionData +) -> tuple[BaseModel, int]: + """ + Extracts the model and key index for a given connection. + + This function determines the model and the index of the key associated with + that model. It handles both frozen and non-frozen states of the model. + + Args: + connection (ConnectionData): The connection data for which the model and + key index need to be extracted. + + Returns: + tuple[BaseModel, int]: A tuple containing the model and the index + of the key for that model. 
+
+    Raises:
+        KeyError: If the connection is not found in the model.
+    """
+    if connection not in model.conns.all.values():
+        raise KeyError(f"Connection {connection.key} is not found in the model")
+
+    if model.is_frozen:
+        _model = model
+        name = connection._name
+        assert name is not None
+        key_index = list(model.external_keys).index(name)
+    else:
+        assert connection.model is not None
+        key = connection.key
+        if (_model := connection.model) is model:
+            # Always return info using frozen models.
+            model_info: (
+                list[tuple[BaseModel, OrderedSet[ConnectionData]]]
+                | tuple[BaseModel, OrderedSet[ConnectionData]]
+            ) = model.dependency_map.local_input_dependency_map.get(
+                connection,
+                model.dependency_map.local_output_dependency_map.get(connection),  # type: ignore
+            )
+            _model = model_info[0][0] if isinstance(model_info, list) else model_info[0]
+            # Get corresponding connection of the _model.
+            _model_conn = _model.conns.get_con_by_metadata(connection.metadata)
+            assert _model_conn is not None
+            key = _model_conn.key
+        key_index = list(_model.external_keys).index(key)
+
+    return (_model, key_index)
+
+
 def _extract_key_info(
     model: BaseModel, submodel_dict: dict[BaseModel, str], conn: ConnectionData
 ) -> tuple[str, int] | str:
-    m, key_index = model.extract_model_key_index(conn)
+    m, key_index = extract_model_key_index(model, conn)
     m_name = submodel_dict[m] if m is not model else "self"
 
     key_name = list(model.external_keys)[key_index]
@@ -149,7 +200,8 @@ def _serialize_assigned_info(
         list[tuple[tuple[str, int] | str, str]],
         list[tuple[tuple[str, int] | str, bool]],
         list[AssignedConstraintType],
-        dict[str, list[tuple[str, int] | str]],
+        list[tuple[str, int] | str],
+        list[tuple[str, int] | str],
     ]:
     shapes_info: list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]] = []
     types_info: list[tuple[tuple[str, int] | str, str]] = []
@@ -195,31 +247,54 @@ def _serialize_assigned_info(
        constraints_info.append(constraint)
 
     # Canonical keys.
- assigned_canonicals: dict[str, list[tuple[str, int] | str]] = { - "cins": [], - "couts": [], - } - for conn in model.assigned_canonicals["cins"]: + assigned_cins: list[tuple[str, int] | str] = [] + assigned_couts: list[tuple[str, int] | str] = [] + for conn in model.assigned_cins: key_info = _extract_key_info(model, submodel_dict, conn) - assigned_canonicals["cins"].append(key_info) - for conn in model.assigned_canonicals["couts"]: + assigned_cins.append(key_info) + for conn in model.assigned_couts: key_info = _extract_key_info(model, submodel_dict, conn) - assigned_canonicals["couts"].append(key_info) + assigned_couts.append(key_info) # Sort cins and couts in order to get exactly the same json for # the same model. - assigned_canonicals["cins"].sort() - assigned_canonicals["couts"].sort() + assigned_cins.sort(key=lambda x: str(x[-1])) # Simple sort criteria. + assigned_couts.sort(key=lambda x: str(x[-1])) return ( shapes_info, types_info, differentiability_info, constraints_info, - assigned_canonicals, + assigned_cins, + assigned_couts, ) -def _construct_info[T]( +def _extract_connection_from_info_key( + model: BaseModel, + submodel_dict: dict[str, BaseModel], + info_key: tuple[str, int], +) -> ConnectionData: + model_name, index = info_key + return _extract_connection_from_index(model, model_name, index, submodel_dict) + + +def _construct_args_info( + model: BaseModel, + submodel_dict: dict[str, BaseModel], + args: set[ConnectionData | str], + info: tuple[str, int] | str, +) -> None: + if isinstance(info, tuple): + connection = _extract_connection_from_info_key(model, submodel_dict, info) + args.add(connection) + elif isinstance(info, str): + args.add(info) + else: + raise RuntimeError("Unknown info format!") + + +def _construct_config_kwargs_info[T]( model: BaseModel, submodel_dict: dict[str, BaseModel], config: dict[ConnectionData, T], @@ -228,10 +303,7 @@ def _construct_info[T]( ) -> None: info_key, data = info if isinstance(info_key, tuple): - model_name, 
index = info_key - connection = _extract_connection_from_index( - model, model_name, index, submodel_dict - ) + connection = _extract_connection_from_info_key(model, submodel_dict, info_key) config[connection] = data elif isinstance(info_key, str): kwargs[info_key] = data @@ -256,28 +328,35 @@ def _set_assigned_info( types_info: list[tuple[tuple[str, int] | str, type]], diffs_info: list[tuple[tuple[str, int] | str, bool]], constraints_info: list[AssignedConstraintType], - canonicals_info: dict[str, list[tuple[str, int] | str]], + cins_info: list[tuple[str, int] | str], + couts_info: list[tuple[str, int] | str], ) -> None: # Shapes conversion. for shp_info in shapes_info: shapes: dict[ConnectionData, ShapeTemplateType] = {} shape_kwargs: dict[str, ShapeTemplateType] = {} for sub_info in shp_info: - _construct_info(model, submodel_dict, shapes, shape_kwargs, sub_info) + _construct_config_kwargs_info( + model, submodel_dict, shapes, shape_kwargs, sub_info + ) model.set_shapes(shapes, **shape_kwargs) # Types conversion. types_config: dict[ConnectionData, type] = {} types_kwargs: dict[str, type] = {} for type_info in types_info: - _construct_info(model, submodel_dict, types_config, types_kwargs, type_info) + _construct_config_kwargs_info( + model, submodel_dict, types_config, types_kwargs, type_info + ) model.set_types(types_config, **types_kwargs) # Differentiability settings for models and keys. diff_config: dict[ConnectionData, bool] = {} diff_kwargs: dict[str, bool] = {} for diff_info in diffs_info: - _construct_info(model, submodel_dict, diff_config, diff_kwargs, diff_info) + _construct_config_kwargs_info( + model, submodel_dict, diff_config, diff_kwargs, diff_info + ) model.set_differentiability(diff_config, **diff_kwargs) # Constraints. 
@@ -294,22 +373,10 @@ def _set_assigned_info( # Canonical keys cins: set[ConnectionData | str] = set() couts: set[ConnectionData | str] = set() - for canonical_type, info in canonicals_info.items(): - related_set = cins if canonical_type == "cins" else couts - for item in info: - if canonical_type not in ["cins", "couts"]: - raise RuntimeError("Unknown canonical key type!") - - if isinstance(item, tuple): - model_name, index = item - connection = _extract_connection_from_index( - model, model_name, index, submodel_dict - ) - related_set.add(connection) - elif isinstance(item, str): - related_set.add(item) - else: - raise RuntimeError("Unknown canonical key info format!") + for item in cins_info: + _construct_args_info(model, submodel_dict, cins, item) + for item in couts_info: + _construct_args_info(model, submodel_dict, couts, item) if cins: model.set_cin(*cins) if couts: @@ -440,7 +507,8 @@ def dict_to_model( assigned_shapes = params.get("assigned_shapes", []) assigned_differentiabilities = params.get("assigned_differentiabilities", []) assigned_constraints = params.get("assigned_constraints", []) - assigned_canonicals = params.get("assigned_canonicals", {}) + assigned_cins = params.get("assigned_cins", []) + assigned_couts = params.get("assigned_couts", []) _set_assigned_info( model, submodels_dict, @@ -448,7 +516,8 @@ def dict_to_model( assigned_types, assigned_differentiabilities, assigned_constraints, - assigned_canonicals, + assigned_cins, + assigned_couts, ) assert isinstance(model, Model) @@ -502,14 +571,16 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: assigned_types, assigned_differentiabilities, assigned_constraints, - assigned_canonicals, + assigned_cins, + assigned_couts, ) = _serialize_assigned_info(model, submodel_obj_dict) model_dict |= {"assigned_shapes": assigned_shapes} model_dict |= {"assigned_types": assigned_types} model_dict |= {"assigned_differentiabilities": assigned_differentiabilities} model_dict |= 
{"assigned_constraints": assigned_constraints} - model_dict |= {"assigned_canonicals": assigned_canonicals} + model_dict |= {"assigned_cins": assigned_cins} + model_dict |= {"assigned_couts": assigned_couts} return model_dict diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 3ad5703a..19809f97 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -17,7 +17,11 @@ from mithril import Backend, Constant, compile, epsilon_table from mithril.framework.common import IOHyperEdge, Tensor from mithril.models import BaseModel, Model, Operator, TrainModel -from mithril.utils.dict_conversions import dict_to_model, model_to_dict +from mithril.utils.dict_conversions import ( + dict_to_model, + extract_model_key_index, + model_to_dict, +) from tests.scripts.test_utils import ( assert_all_conn_key_are_same, convert_to_array, @@ -220,13 +224,13 @@ def assert_models_equal(model1: BaseModel, model2: BaseModel): model1_cins = set() model2_cins = set() for conn in model1.conns.cins: - model, key_index = model1.extract_model_key_index(conn) + model, key_index = extract_model_key_index(model1, conn) model_index = ( list(model1.dag.keys()).index(model) if model != model1 else "self" ) model1_cins.add((model_index, key_index)) for conn in model2.conns.cins: - model, key_index = model2.extract_model_key_index(conn) + model, key_index = extract_model_key_index(model2, conn) model_index = ( list(model2.dag.keys()).index(model) if model != model2 else "self" ) @@ -236,13 +240,13 @@ def assert_models_equal(model1: BaseModel, model2: BaseModel): model1_couts = set() model2_couts = set() for conn in model1.conns.couts: - model, key_index = model1.extract_model_key_index(conn) + model, key_index = extract_model_key_index(model1, conn) model_index = ( list(model1.dag.keys()).index(model) if model != model1 else "self" ) model1_couts.add((model_index, key_index)) for conn in model2.conns.couts: - model, key_index = model2.extract_model_key_index(conn) + model, 
key_index = extract_model_key_index(model2, conn) model_index = ( list(model2.dag.keys()).index(model) if model != model2 else "self" ) diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index 6b86076b..54f77c3c 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -169,6 +169,57 @@ def test_linear_not_expose(): ) +def test_set_cins_couts(): + model = Model() + linear_1 = Linear(dimension=42) + model |= linear_1(input="input", weight="weight", output=IOKey(name="output")) + model.set_cin("weight", linear_1.bias) + outer_model = Model() + linear_2 = Linear(dimension=42) + outer_model |= model(input="input_1", output=IOKey(name="output_1")) + outer_model |= linear_2(input="input_2", output=IOKey(name="output_2")) + outer_model.set_cout("output_2") + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert ( + model_dict_created["assigned_cins"] # type: ignore + == model_dict_recreated["assigned_cins"] # type: ignore + == [] + ) + assert ( + model_dict_created["assigned_couts"] # type: ignore + == model_dict_recreated["assigned_couts"] # type: ignore + == [("m_1", 4)] + ) + assert ( + model_dict_created["submodels"]["m_0"]["assigned_cins"] # type: ignore + == model_dict_recreated["submodels"]["m_0"]["assigned_cins"] # type: ignore + == [("self", 2), "weight"] + ) + assert ( + model_dict_created["submodels"]["m_0"]["assigned_couts"] # type: ignore + == model_dict_recreated["submodels"]["m_0"]["assigned_couts"] # type: ignore + == [] + ) + + assert_models_equal(outer_model, model_recreated) + + backend = JaxBackend(dtype=mithril.float64) + assert_evaluations_equal( + outer_model, + model_recreated, + backend, + static_keys={ + "input_1": backend.ones([4, 256]), + 
"input_2": backend.ones([4, 256]), + }, + ) + + def test_constant_key(): model = Model() model | Add()(left="input", right=Tensor(3), output=IOKey(name="output"))