diff --git a/examples/flux/t5.py b/examples/flux/t5.py index e4c78575..e8fbf58b 100644 --- a/examples/flux/t5.py +++ b/examples/flux/t5.py @@ -180,9 +180,9 @@ def load_t5_encoder( repo_id: str = "black-forest-labs/FLUX.1-schnell", max_len: int = 256, ) -> ml.models.PhysicalModel: - config = hf_hub_download(repo_id, "text_encoder_2/config.json") + config_path = hf_hub_download(repo_id, "text_encoder_2/config.json") - with open(config) as f: + with open(config_path) as f: config = json.load(f) t5 = t5_encode(config, name="encoder") diff --git a/examples/flux/util.py b/examples/flux/util.py index c6a583f9..0c7ada72 100644 --- a/examples/flux/util.py +++ b/examples/flux/util.py @@ -159,11 +159,11 @@ def load_flow_model(name: str, backend: ml.Backend, hf_download: bool = True): ckpt_path = configs[name].ckpt_path if ( ckpt_path is None - and configs[name].repo_id is not None - and configs[name].repo_flow is not None + and (r_id := configs[name].repo_id) is not None + and (r_flow := configs[name].repo_flow) is not None and hf_download ): - ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + ckpt_path = hf_hub_download(r_id, r_flow) flux_lm = flux(configs[name].params) flux_pm = ml.compile( @@ -187,11 +187,11 @@ def load_decoder( ckpt_path = configs[name].ae_path if ( ckpt_path is None - and configs[name].repo_id is not None - and configs[name].repo_ae is not None + and (r_id := configs[name].repo_id) is not None + and (r_ae := configs[name].repo_ae) is not None and hf_download ): - ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae) + ckpt_path = hf_hub_download(r_id, r_ae) # Loading the autoencoder print("Init AE") diff --git a/mithril/cores/python/numpy/ops.py b/mithril/cores/python/numpy/ops.py index f4cce382..9ab75ea4 100644 --- a/mithril/cores/python/numpy/ops.py +++ b/mithril/cores/python/numpy/ops.py @@ -667,7 +667,7 @@ def scaled_dot_product_attention( ) L, S = query.shape[-2], key.shape[-2] - scale_factor = 1 / np.sqrt(query.shape[-1]) if scale is None else scale + scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale write_into_cache(cache, "scale_factor", scale_factor) attn_bias = np.zeros((L, S), dtype=query.dtype) if is_causal: diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 6988f228..e6baca95 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -1043,7 +1043,7 @@ def __init__( value: TensorValueType | ToBeDetermined = TBD, type: _TensorTypes = int | float | bool, shape: ShapeNode | None = None, - differentiable: bool = False, + differentiable: bool | None = None, ): if shape is None: # If shape is not provided, create a new shape with a Variadic root. @@ -1119,8 +1119,16 @@ def match(self, other: Tensor[int | float | bool]) -> Updates: updates |= non_valued.set_value(valued.value) self.differentiable = False other.differentiable = False - else: - self.differentiable |= other.differentiable + elif self.differentiable is None: + self.differentiable = other.differentiable + # Differentiable tensors can only be float type. + if self.differentiable: + updates |= self.set_type(float) + elif ( + other.differentiable is not None + and self.differentiable != other.differentiable + ): + raise ValueError("Differentiability mismatch!") # Match shapes. updates |= self.match_shapes(other.shape) updates.shape_updates.discard(other) @@ -1180,8 +1188,16 @@ def _temp_shape(self) -> ShapeRepr | None: return None @property - def differentiable(self) -> bool: - return isinstance(self._value, Tensor) and self._value.differentiable + def differentiable(self) -> bool | None: + if isinstance(self._value, Tensor): + return self._value.differentiable + elif self.is_scalar: + # Scalars are always non-differentiable. + return False + # Differentiability of polymorphic edges are defined + # as None. Depending on its instant type updates, it can + # become True or False (e.g Tensor or int type). + return None @property def tensors(self) -> set[Tensor[int | float | bool]]: @@ -1405,7 +1421,7 @@ def set_value( updates.value_updates.add(self) # Update new type without automatic tensor value creation. updates |= self.set_type(find_type(self._value), create_tensor=False) - if self.is_valued: + if self.is_tensor and self.is_valued: self.set_differentiability(False) return updates @@ -1431,13 +1447,19 @@ def match(self, other: IOHyperEdge) -> Updates: return updates - def set_differentiability(self, differentiable: bool) -> None: - if self.is_tensor: - assert isinstance(self._value, Tensor) - self._value.differentiable = differentiable - elif differentiable: + def set_differentiability(self, differentiable: bool) -> Updates: + if self.is_scalar and differentiable: raise ValueError("Non-tensor edges cannot be differentiable.") + updates = Updates() + if differentiable: + # Differentiable edges can only be Tensor[float] type. + updates |= self.set_type(Tensor[float]) + # Set differentiability of the _value if it is a Tensor. + if isinstance(self._value, Tensor): + self._value.differentiable = differentiable + return updates + def add_constraint(self, constraint: Constraint) -> None: for type in constraint.types: self.constraints[type].add(constraint) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 6c82ae7a..193beca3 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -75,7 +75,7 @@ def __init__( | ScalarType | None = None, expose: bool | None = None, - differentiable: bool = False, + differentiable: bool | None = None, interval: list[float | int] | None = None, ) -> None: # If shape is provided, type should be Tensor. @@ -166,15 +166,7 @@ def __hash__(self) -> int: return hash(id(self)) def set_differentiability(self, differentiable: bool = True) -> Updates: - updates = Updates() - # TODO: Move this method to Model class as set_shapes, set_types etc. - if self.metadata.is_tensor: - self.metadata.set_differentiability(differentiable) - elif differentiable: - updates |= self.metadata.set_type(Tensor[float]) - self.metadata.set_differentiability(differentiable) - - return updates + return self.metadata.set_differentiability(differentiable) BaseKey = ConnectionData @@ -380,15 +372,17 @@ def __init__( ) -> None: self.dag: dict[BaseModel, dict[str, ConnectionData]] = {} self._formula_key: str | None = formula_key - # TODO: maybe set it only to Operator / Model. self.parent: BaseModel | None = None - self.assigned_shapes: list[dict[str, ShapeTemplateType]] = [] + self.assigned_shapes: list[dict[ConnectionData, ShapeTemplateType]] = [] self.assigned_types: dict[ - str, + ConnectionData, type | UnionType | ScalarType | type[Tensor[int | float | bool]], ] = {} + self.assigned_differentiabilities: dict[ConnectionData, bool] = {} self.assigned_constraints: list[AssignedConstraintType] = [] + self.assigned_cins: set[ConnectionData] = set() + self.assigned_couts: set[ConnectionData] = set() self.conns = Connections() self.frozen_attributes: list[str] = [] self.dependency_map = DependencyMap(self.conns) @@ -510,12 +504,14 @@ def _add_connection( local_key: str, given_connection: ConnectionDataType, updates: Updates, + trace: bool, ) -> tuple[ConnectionData, Updates]: is_input = local_key in model.input_keys local_connection = model.conns.get_connection(local_key) assert local_connection is not None, "Connection is not found!" edge = local_connection.metadata is_not_valued = not edge.is_valued + set_diff = None d_map = self.dependency_map.local_output_dependency_map @@ -527,6 +523,7 @@ def _add_connection( | Tensor[int | float | bool] | NullConnection ) = NOT_GIVEN + set_type: type[Tensor[int | float | bool]] | ScalarType = ToBeDetermined is_new_connection = False match given_connection: @@ -545,6 +542,13 @@ def _add_connection( set_value = given_connection given_connection = self._create_connection(edge, None) assert isinstance(given_connection, ConnectionData) + + if ( + given_connection.metadata.differentiable is not None + and given_connection.metadata.differentiable != edge.differentiable + ): + set_diff = given_connection.metadata.differentiable + # Connection is given as a Connection object. if ( con_obj := self.conns.get_con_by_metadata(given_connection.metadata) @@ -554,6 +558,8 @@ def _add_connection( is_new_connection = True expose = given_connection.is_exposed outer_key = given_connection.get_key() + if set_value is NOT_GIVEN: + set_type = given_connection.metadata.edge_type if set_value is NOT_GIVEN and given_connection.metadata.value is not TBD: set_value = given_connection.metadata._value if outer_key is not None: @@ -579,15 +585,22 @@ def _add_connection( "Expose flag cannot be false when " "no value is provided for input keys!" ) + + # Set value or type if given. if not isinstance(set_value, NullConnection): updates |= con_obj.metadata.set_value(set_value) - elif ( - set_type := given_connection.metadata.edge_type - ) is not ToBeDetermined: - model.set_types({local_connection: set_type}) - - if given_connection.metadata.differentiable: - updates |= con_obj.set_differentiability(True) + elif set_type is not ToBeDetermined: + # Skip tracing if the local connection's type is already + # set to the given type. + trace &= set_type != local_connection.metadata.edge_type + model._set_types({local_connection: set_type}, trace=trace) + + # Set differentiability if given. + if set_diff is not None: + # No need to trace differentiability for valued and + # existing connections. + trace &= is_new_connection and not given_connection.metadata.is_valued + model._set_differentiability({local_connection: set_diff}, trace) else: if given_connection in model.conns.all.values(): @@ -655,6 +668,7 @@ def _add_connection( or con_obj in self.conns.output_connections ): self.conns.couts.add(con_obj) + return con_obj, updates def rename_key(self, connection: ConnectionData, key: str) -> None: @@ -782,11 +796,16 @@ def _merge_connections( # TODO: Deleted connection's 'key' attribute is not updated. # Consider updating it. + # Update assigned attributes with conn2 to conn1. + self._update_assigned_attributes(conn1, conn2) + return updates def extend( self, model: BaseModel | BaseModel, + trace: bool = True, + /, **kwargs: ConnectionDataType, ) -> None: # Check possible errors before the extension. @@ -836,7 +855,9 @@ def extend( ): value._expose = True - con_obj, _updates = self._add_connection(model, local_key, value, updates) + con_obj, _updates = self._add_connection( + model, local_key, value, updates, trace + ) updates |= _updates submodel_dag[local_key] = con_obj if tensors := con_obj.metadata.tensors: @@ -848,6 +869,12 @@ def extend( submodel_dag[key].key: template for key, template in shape_info.items() } + # Insert to self dag as a FrozenDict."" + # Since we update dag in merge_connections, we could not use FrozenDict. + self.dag[model] = model_dag = submodel_dag + + self.dependency_map.add_model_dag(model, model_dag) + # Set given shapes. self._set_shapes(**shape_info) # TODO: Should "trace" be set to True?. self.constraint_solver(updates) @@ -855,12 +882,6 @@ def extend( model.constraint_solver.clear() model.conns.connections_dict = {} - # Insert to self dag as a FrozenDict."" - # Since we update dag in merge_connections, we could not use FrozenDict. - self.dag[model] = model_dag = submodel_dag - - self.dependency_map.add_model_dag(model, model_dag) - # Update jittablity by using model's jittablity. self._jittable &= model.jittable @@ -1356,8 +1377,12 @@ def _create_connection( connection.metadata = metadata return connection - def set_differentiability( - self, config: dict[ConnectionData, bool] | None = None, /, **kwargs: bool + def _set_differentiability( + self, + config: dict[ConnectionData, bool] | None = None, + trace: bool = False, + /, + **kwargs: bool, ) -> None: updates = Updates() if config is None: @@ -1373,12 +1398,20 @@ def set_differentiability( elif isinstance(key, ConnectionData): if key not in self.conns.all.values(): raise KeyError(f"Connection {key} is not found in the model.") + conn_data = key + updates |= conn_data.set_differentiability(value) - updates |= key.set_differentiability(value) + if trace: + self.assigned_differentiabilities[conn_data] = value model = self._get_outermost_parent() model.constraint_solver(updates) + def set_differentiability( + self, config: dict[ConnectionData, bool] | None = None, /, **kwargs: bool + ) -> None: + self._set_differentiability(config, True, **kwargs) + def _set_shapes( self, shapes: Mapping[ConnectionData, ShapeTemplateType] | None = None, @@ -1387,7 +1420,7 @@ def _set_shapes( **kwargs: ShapeTemplateType, ) -> None: # Initialize assigned shapes dictionary to store assigned shapes. - assigned_shapes: dict[str, ShapeTemplateType] = {} + assigned_shapes: dict[ConnectionData, ShapeTemplateType] = {} updates = Updates() if shapes is None: shapes = {} @@ -1406,7 +1439,10 @@ def _set_shapes( assert conn is not None inner_key = conn.key shape_nodes[key] = (given_repr.node, inner_key) - assigned_shapes[inner_key] = shape + # In order to store assigned shapes, we need to store corresponding model + # and index of the connection for that model. + assigned_shapes[conn] = shape + # Apply updates to the shape nodes. for key in chain(shapes, kwargs): assert isinstance(key, str | ConnectionData) @@ -1444,12 +1480,6 @@ def _set_types( ) -> None: # Initialize assigned shapes dictionary to store assigned shapes. if config is None: config = {} - - assigned_types: dict[ - str, - type | UnionType | ScalarType | type[Tensor[int | float | bool]], - ] = {} - # Get the outermost parent as all the updates will happen here. model = self._get_outermost_parent() updates = Updates() @@ -1458,12 +1488,12 @@ def _set_types( metadata = self.conns.extract_metadata(key) conn = self.conns.get_con_by_metadata(metadata) assert conn is not None - inner_key = conn.key - assigned_types[inner_key] = key_type updates |= metadata.set_type(key_type) - if trace: - # Store assigned types in the model. - self.assigned_types |= assigned_types + if trace: + # Store assigned types in the model. + if key_type is Tensor: + key_type = Tensor[int | float | bool] + self.assigned_types[conn] = key_type # Run the constraints for updating affected connections. model.constraint_solver(updates) @@ -1568,6 +1598,7 @@ def _add_constraint( keys: list[str], types: list[UpdateType] | None = None, dependencies: set[Constraint] | None = None, + trace: bool = False, ) -> Constraint: all_conns = self.conns.all hyper_edges = [all_conns[key].metadata for key in keys] @@ -1587,6 +1618,10 @@ def _add_constraint( hyper_edge.add_constraint(constr) self.constraint_solver.solver_loop({constr}) + + if trace: + self.assigned_constraints.append({"fn": fn.__name__, "keys": keys}) + return constr def add_constraint( @@ -1596,8 +1631,7 @@ def add_constraint( type: list[UpdateType] | None = None, dependencies: set[Constraint] | None = None, ) -> Constraint: - self.assigned_constraints.append({"fn": fn.__name__, "keys": keys}) - return self._add_constraint(fn, keys, type, dependencies) + return self._add_constraint(fn, keys, type, dependencies, True) @property def cin(self) -> ConnectionData: @@ -1619,6 +1653,11 @@ def cout(self) -> ConnectionData: return next(iter(self.conns.couts)) def set_cin(self, *connections: str | ConnectionData, safe: bool = True) -> None: + self._set_cin(*connections, safe=safe, trace=True) + + def _set_cin( + self, *connections: str | ConnectionData, safe: bool = True, trace: bool = False + ) -> None: self.conns.cins = set() for given_conn in connections: conn = self.conns.get_extracted_connection(given_conn) @@ -1637,8 +1676,15 @@ def set_cin(self, *connections: str | ConnectionData, safe: bool = True) -> None ) else: self.conns.cins.add(conn) + if trace: + self.assigned_cins.add(conn) def set_cout(self, *connections: str | ConnectionData, safe: bool = True) -> None: + self._set_cout(*connections, safe=safe, trace=True) + + def _set_cout( + self, *connections: str | ConnectionData, safe: bool = True, trace: bool = False + ) -> None: self.conns.couts = set() for given_conn in connections: conn = self.conns.get_extracted_connection(given_conn) @@ -1651,6 +1697,8 @@ def set_cout(self, *connections: str | ConnectionData, safe: bool = True) -> Non ) else: self.conns.couts.add(conn) + if trace: + self.assigned_couts.add(conn) def _match_hyper_edges(self, left: IOHyperEdge, right: IOHyperEdge) -> Updates: l_type = left.edge_type @@ -1687,6 +1735,45 @@ def _match_hyper_edges(self, left: IOHyperEdge, right: IOHyperEdge) -> Updates: # Match data of each IOHyperEdge's. return left.match(right) + def _update_assigned_attributes( + self, new: ConnectionData, old: ConnectionData + ) -> None: + """ + Update assigned attributes by replacing occurrences of the old ConnectionData + with the new ConnectionData. + + This method updates the following attributes: + - assigned_shapes: Replaces old ConnectionData with new ConnectionData + in the assigned shapes. + - assigned_types: Replaces old ConnectionData with new ConnectionData + in the assigned types. + - assigned_differentiabilities: Replaces old ConnectionData with new + ConnectionData in the assigned differentiabilities. + - assigned_canonicals: Updates the 'cins' and 'couts' sets by removing the + old ConnectionData and adding the new ConnectionData. + + Args: + new (ConnectionData): The new ConnectionData to replace the old one. + old (ConnectionData): The old ConnectionData to be replaced. + """ + for shape_info in self.assigned_shapes: + if old in shape_info: + shape_info[new] = shape_info.pop(old) + if old in self.assigned_types: + self.assigned_types[new] = self.assigned_types.pop(old) + + if old in self.assigned_differentiabilities: + self.assigned_differentiabilities[new] = ( + self.assigned_differentiabilities.pop(old) + ) + # Assigned canonicals + if old in self.assigned_cins: + self.assigned_cins.remove(old) + self.assigned_cins.add(new) + elif old in self.assigned_couts: + self.assigned_couts.remove(old) + self.assigned_couts.add(new) + class DependencyMap: """ diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index cc098ba3..70041bb1 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -536,7 +536,10 @@ def cin(self) -> Connection: return cin def _extend( - self, model: BaseModel, kwargs: Mapping[str, ConnectionType] | None = None + self, + model: BaseModel, + kwargs: Mapping[str, ConnectionType] | None = None, + trace: bool = True, ) -> Self: if kwargs is None: kwargs = {} @@ -560,7 +563,7 @@ def _extend( kwargs[key] = _value # type: ignore kwargs[key] = self._unroll_template(kwargs[key]) # type: ignore - self.extend(model, **kwargs) # type: ignore + self.extend(model, trace, **kwargs) # type: ignore return self def __add__(self, info: ExtendInfo | Model) -> Self: diff --git a/mithril/framework/logical/operator.py b/mithril/framework/logical/operator.py index 85cb9295..4a964696 100644 --- a/mithril/framework/logical/operator.py +++ b/mithril/framework/logical/operator.py @@ -106,7 +106,7 @@ def __init__( ) canonical_input_conn = self.conns.get_connection(canonical_input_key) if canonical_input_conn is not None: - self.set_cin(canonical_input_conn, safe=False) + self._set_cin(canonical_input_conn, safe=False) canonical_output_key = ( "output" @@ -115,7 +115,7 @@ def __init__( ) canonical_output_conn = self.conns.get_connection(canonical_output_key) if canonical_output_conn is not None: - self.set_cout(canonical_output_conn, safe=False) + self._set_cout(canonical_output_conn, safe=False) self._freeze() @property @@ -130,6 +130,8 @@ def class_name(self) -> str: def extend( self, model: BaseModel | BaseModel, + trace: bool = True, + /, **kwargs: ConnectionDataType, ) -> None: raise NotImplementedError("Operators cannot be extended!") diff --git a/mithril/framework/logical/operators.py b/mithril/framework/logical/operators.py index e2480459..c3f885b5 100644 --- a/mithril/framework/logical/operators.py +++ b/mithril/framework/logical/operators.py @@ -163,7 +163,7 @@ def __init__( fn=to_tuple_constraints, keys=[Operator.output_key] + [key for key in self.input_keys], ) - self.set_cin() + self._set_cin() class PowerOp(Operator): @@ -301,7 +301,7 @@ def __init__( types=[UpdateType.TYPE], dependencies={edge_constraint}, ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class SubtractOp(Operator): @@ -410,7 +410,7 @@ def __init__( types=[UpdateType.TYPE], dependencies={edge_constraint}, ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class MinimumOp(Operator): @@ -451,7 +451,7 @@ def __init__( keys=[Operator.output_key, "left", "right"], types=[UpdateType.TYPE], ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class MaximumOp(Operator): @@ -492,7 +492,7 @@ def __init__( keys=[Operator.output_key, "left", "right"], types=[UpdateType.TYPE], ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class DivideOp(Operator): @@ -851,7 +851,7 @@ def __init__( fn=to_list_constraints, keys=[Operator.output_key] + [key for key in self.input_keys], ) - self.set_cin() + self._set_cin() class TensorToListOp(Operator): @@ -1298,7 +1298,7 @@ def __init__( super().__init__( formula_key="equal", name=name, operator=operator.eq, left=left, right=right ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class NotEqualOp(RelationalOperatorsOp): @@ -1318,7 +1318,7 @@ def __init__( left=left, right=right, ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class LessEqualOp(RelationalOperatorsOp): @@ -1428,7 +1428,7 @@ def __init__( keys=[Operator.output_key, "left", "right"], dependencies={edge_constraint}, ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) class LogicalAndOp(BitwiseOperatorsOp): @@ -1684,7 +1684,7 @@ def __init__( self._add_constraint( fn=slice_constraints, keys=["output", "start", "stop", "step"] ) - self.set_cin() + self._set_cin() class IndexerOp(Operator): diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index f543dc50..6eb7e51b 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -30,8 +30,9 @@ def __init__( name: str | None = None, ) -> None: super().__init__(name=name, enforce_jit=model._jittable) - self._extend(model, {k: k for k in model.external_keys}) + self._extend(model, {k: k for k in model.external_keys}, trace=False) self.expose_keys(*model.external_keys) + self._freeze() @property def submodel(self) -> Operator: @@ -42,11 +43,13 @@ def submodel(self) -> Operator: def extend( self, model: BaseModel, + trace: bool = True, + /, **kwargs: ConnectionDataType, ) -> None: if len(self.dag) > 0: raise RuntimeError("Primitive models cannot have submodels.") - super().extend(model, **kwargs) + super().extend(model, trace, **kwargs) class PrimitiveModel(OperatorModel): diff --git a/mithril/models/models.py b/mithril/models/models.py index f07e7dfd..6b99fe9f 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -200,7 +200,7 @@ def __init__( output=IOKey(name="output"), ) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -301,7 +301,7 @@ def __init__( dilation=dt_converter.output, output=IOKey(name="output"), ) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -397,7 +397,7 @@ def __init__( conv_connections["bias"] = IOKey("bias", differentiable=True) self |= PrimitiveConvolution1D(use_bias=use_bias)(**conv_connections) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -491,7 +491,7 @@ def __init__( conv_connections["bias"] = IOKey("bias", differentiable=True) self |= PrimitiveConvolution2D(use_bias=use_bias)(**conv_connections) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -558,7 +558,7 @@ def __init__( self |= mult(left=input_key, right=weight_key, output=output) self._set_shapes(**shapes) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -605,7 +605,7 @@ def __init__( right=IOKey(name="bias", value=bias), output=IOKey(name="output"), ) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -648,7 +648,7 @@ def __init__( bias=IOKey("bias", value=bias), ) self += activation(output=IOKey(name="output")) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -718,7 +718,7 @@ def __init__( add._set_shapes(**shapes) # TODO: Remove below Buffer after required naming-related changes are done. self |= Buffer()(input=self.cout, output=IOKey(name="output")) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -799,7 +799,7 @@ def __init__( add._set_shapes(**shapes) self |= Buffer()(input=self.cout, output=IOKey(name="output")) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -842,8 +842,8 @@ def __init__( self |= abs_model(input=IOKey("input", value=input)) self += Sum()(output=IOKey(name="output")) - self.set_cin("input", safe=False) - self.set_cout("output", safe=False) + self._set_cin("input", safe=False) + self._set_cout("output", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -870,8 +870,8 @@ def __init__( self += Square()(input=IOKey("input", value=input)) self += Sum() self += Multiply()(right=Tensor(0.5), output=IOKey(name="output")) - self.set_cin("input", safe=False) - self.set_cout("output", safe=False) + self._set_cin("input", safe=False) + self._set_cout("output", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -910,7 +910,7 @@ def __init__( ) shapes: dict[str, ShapeTemplateType] = {"input": [1, "N"], "kernel": ["N", "N"]} self._set_shapes(**shapes) - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -983,7 +983,7 @@ def __init__( } self._set_shapes(**shapes) - self.set_cin("input1", "input2", safe=False) + self._set_cin("input1", "input2", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -1044,7 +1044,7 @@ def __init__( fn=polynomial_kernel_constraint, keys=["poly_coef", "degree"], ) - self.set_cin("input1", "input2", safe=False) + self._set_cin("input1", "input2", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -1107,6 +1107,8 @@ def __init__( output=IOKey(name="output"), ) + # TODO: It is not clear where these "input1" and "input2" names come from. + # It assumes kernel model has two inputs named "input1" and "input2". shapes: dict[str, ShapeTemplateType] = { "input1": ["N", "d_in"], "input2": ["M", "d_in"], @@ -1116,7 +1118,7 @@ def __init__( "kernel": ["N", "M"], } self._set_shapes(**shapes) - self.set_cin("input1", "input2", safe=False) + self._set_cin("input1", "input2", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -1164,7 +1166,7 @@ def __init__( ) self += decision_model(output=IOKey(name="decision_output")) - self.set_cout(linear_model.output) + self._set_cout(linear_model.output) self._freeze() def __call__( # type: ignore[override] @@ -1214,7 +1216,7 @@ def __init__( input=linear_model.output, output=IOKey(name="probs_output") ) - self.set_cout(linear_model.output) + self._set_cout(linear_model.output) self._freeze() def __call__( # type: ignore[override] @@ -1296,7 +1298,7 @@ def __init__( # Add current layer to the model. self += current_layer(**kwargs) prev_layer = current_layer - self.set_cin("input", safe=False) + self._set_cin("input", safe=False) self._freeze() def __call__( @@ -1415,8 +1417,8 @@ def __init__( } self._set_shapes(**shapes) - self.set_cin("input", safe=False) - self.set_cout("output") + self._set_cin("input", safe=False) + self._set_cout("output") self._freeze() def __call__( # type: ignore[override] @@ -1590,8 +1592,8 @@ def __init__( } self._set_shapes(**shapes) - self.set_cin("input", safe=False) - self.set_cout("output") + self._set_cin("input", safe=False) + self._set_cout("output") self._freeze() def __call__( # type: ignore[override] @@ -1870,8 +1872,8 @@ def __init__( ) prev_cell = current_cell - self.set_cin("input") - self.set_cout(current_cell.output) + self._set_cin("input") + self._set_cout(current_cell.output) self._freeze() def __call__( @@ -1996,8 +1998,8 @@ def __init__( input=concat_input_args, output=IOKey(name="hidden_concat", value=hidden_concat), ) - self.set_cin("input0") - self.set_cout("hidden_concat") + self._set_cin("input0") + self._set_cout("hidden_concat") self._freeze() def __call__( @@ -2053,7 +2055,7 @@ def __init__( initial_hidden=permutation_model.output, **(dec_input_mapping | dec_output_mapping), ) - self.set_cout(decoder.cout) + self._set_cout(decoder.cout) self._freeze() @@ -2101,7 +2103,7 @@ def __init__( initial_hidden=encoder.hidden_concat, **(dec_input_mapping | dec_output_mapping), ) - self.set_cout(decoder.cout) + self._set_cout(decoder.cout) self._freeze() def __call__(self, **model_keys: ConnectionType) -> ExtendInfo: @@ -2155,7 +2157,7 @@ def __init__( norm=modifier_model.output, output=IOKey(name="output"), ) - self.set_cin("input1", "input2", safe=False) + self._set_cin("input1", "input2", safe=False) self._freeze() def __call__( # type: ignore[override] @@ -2379,8 +2381,8 @@ def __init__( ) self._set_shapes(distances=["N", "N"], pred_distances=["N", "N"]) - self.set_cin("distances", safe=False) - self.set_cout("output") + self._set_cin("distances", safe=False) + self._set_cout("output") self._freeze() def __call__( # type: ignore[override] @@ -2717,7 +2719,7 @@ def __init__( output=IOKey(name="confidence", value=confidence), ) - self.set_cout(pred_model.output) + self._set_cout(pred_model.output) shapes: dict[str, ShapeTemplateType] = { "label": ["N", 1], "s": [1], @@ -2898,7 +2900,7 @@ def __init__( self |= Buffer()(input=label_key, output=IOKey("label_formatted")) self |= Buffer()(input=result, output=IOKey("output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] @@ -2960,7 +2962,7 @@ def __init__( denominator=n_prediction.tensor(), output=IOKey(name="output"), ) - self.set_cin(self.pred) + self._set_cin(self.pred) def __call__( # type: ignore[override] self, @@ -3119,7 +3121,7 @@ def __init__( self |= Buffer()(input=precision, output=IOKey(name="output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] @@ -3278,7 +3280,7 @@ def __init__( self |= Buffer()(input=recall, output=IOKey(name="output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] @@ -3442,7 +3444,7 @@ def __init__( self |= Buffer()(input=precision, output=IOKey(name="output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] @@ -3503,7 +3505,7 @@ def __init__( self |= Buffer()(auc_score, IOKey("output")) - self.set_cin(self.pred) + self._set_cin(self.pred) self._freeze() def __call__( # type: ignore[override] diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index d7beb19f..10eb2eb7 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -1748,7 +1748,7 @@ def __init__( fn=general_tensor_type_constraint, keys=[Operator.output_key, "left", "right", "norm"], ) - self.set_cin("left", "right", safe=False) + self._set_cin("left", "right", safe=False) def __call__( # type: ignore[override] self, @@ -2068,7 +2068,7 @@ def __init__( step=BaseKey(type=int | float, value=step), dtype=BaseKey(type=types.Dtype | None, value=dtype), ) - self.set_cin("stop", safe=False) + self._set_cin("stop", safe=False) if not all_defined: self._add_constraint( @@ -2493,7 +2493,7 @@ def __init__( fn=general_tensor_type_constraint, keys=[Operator.output_key, "input1", "input2"], ) - self.set_cin("input1", safe=False) + self._set_cin("input1", safe=False) def __call__( # type: ignore[override] self, diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 18a2d6be..ec565d00 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -46,7 +46,7 @@ ) from ..models.train_model import TrainModel from ..utils import model_conversion_lut -from ..utils.utils import convert_to_tuple +from ..utils.utils import OrderedSet, convert_to_tuple class KeyDict(TypedDict, total=False): @@ -80,16 +80,17 @@ class RegDict(TypedDict): class ModelDict(TypedDict, total=False): name: str args: dict[str, Any] - assigned_shapes: dict[str, ShapeTemplateType] - differentiability_info: dict[str, bool] + assigned_shapes: list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]] + assigned_types: list[tuple[tuple[str, int] | str, str]] + assigned_differentiabilities: list[tuple[tuple[str, int] | str, bool]] assigned_constraints: list[AssignedConstraintType] + assigned_cins: list[tuple[str, int] | str] + assigned_couts: list[tuple[str, int] | str] tuples: list[str] enums: dict[str, str] unnamed_keys: list[str] - types: dict[str, str] submodels: dict[str, ModelDict] connections: dict[str, dict[str, str | ConnectionDict]] - canonical_keys: dict[str, tuple[set[str], set[str]]] class TrainModelDict(TypedDict): @@ -125,6 +126,263 @@ class TrainModelDict(TypedDict): enum_dict = {"PaddingType": PaddingType} +def extract_model_key_index( + model: BaseModel, connection: ConnectionData +) -> tuple[BaseModel, int]: + """ + Extracts the model and key index for a given connection. + + This function determines the model and the index of the key associated with + that model. It handles both frozen and non-frozen states of the model. + + Args: + connection (ConnectionData): The connection data for which the model and + key index need to be extracted. + + Returns: + tuple[BaseModel, int]: A tuple containing the model and the index + of the key for that model. + + Raises: + KeyError: If the connection is not found in the model. + """ + if connection not in model.conns.all.values(): + raise KeyError(f"Connection {connection.key} is not found in the model") + + if model.is_frozen: + _model = model + name = connection._name + assert name is not None + key_index = list(model.external_keys).index(name) + else: + assert connection.model is not None + key = connection.key + if (_model := connection.model) is model: + # Always return info using freezed models. + model_info: ( + list[tuple[BaseModel, OrderedSet[ConnectionData]]] + | tuple[BaseModel, OrderedSet[ConnectionData]] + ) = model.dependency_map.local_input_dependency_map.get( + connection, + model.dependency_map.local_output_dependency_map.get(connection), # type: ignore + ) + _model = model_info[0][0] if isinstance(model_info, list) else model_info[0] + # Get corresponding connection of the _model. + _model_conn = _model.conns.get_con_by_metadata(connection.metadata) + assert _model_conn is not None + key = _model_conn.key + key_index = list(_model.external_keys).index(key) + + return (_model, key_index) + + +def _extract_key_info( + model: BaseModel, submodel_dict: dict[BaseModel, str], conn: ConnectionData +) -> tuple[str, int] | str: + m, key_index = extract_model_key_index(model, conn) + m_name = submodel_dict[m] if m is not model else "self" + + key_name = list(model.external_keys)[key_index] + if m_name == "self" and not key_name.startswith("$"): + # If corresponding connection is named, save info + # only with this name and shape. + key_info: tuple[str, int] | str = key_name + else: + key_info = (m_name, key_index) + + return key_info + + +def _serialize_assigned_info( + model: BaseModel, submodel_dict: dict[BaseModel, str] +) -> tuple[ + list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]], + list[tuple[tuple[str, int] | str, str]], + list[tuple[tuple[str, int] | str, bool]], + list[AssignedConstraintType], + list[tuple[str, int] | str], + list[tuple[str, int] | str], +]: + shapes_info: list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]] = [] + types_info: list[tuple[tuple[str, int] | str, str]] = [] + differentiability_info: list[tuple[tuple[str, int] | str, bool]] = [] + constraints_info: list[AssignedConstraintType] = [] + + # Shapes info. + for shp_info in model.assigned_shapes: + info_list: list[tuple[tuple[str, int] | str, ShapeTemplateType]] = [] + for conn, shape_info in shp_info.items(): + # Key info. + key_info = _extract_key_info(model, submodel_dict, conn) + # Shape info. + shape_list: list[int | str | tuple[str, EllipsisType] | None] = [] + for item in shape_info: + if isinstance(item, tuple): # variadic + shape_list.append(f"{item[0]},...") + else: + shape_list.append(item) + # Combine key info with shape info. + info_list.append((key_info, shape_list)) + shapes_info.append(info_list) + + # Types info. + for conn, typ in model.assigned_types.items(): + # Key info. + key_info = _extract_key_info(model, submodel_dict, conn) + # Combine key info with type info. + if get_origin(typ) is Tensor: + types_info.append((key_info, "tensor")) + elif typ is not ToBeDetermined: + types_info.append((key_info, str(typ.__name__))) # type: ignore + + # Differentiability info. + for conn, status in model.assigned_differentiabilities.items(): + # Key info. + key_info = _extract_key_info(model, submodel_dict, conn) + # Combine key info with differentiability status. + differentiability_info.append((key_info, status)) + + # Constraints. + for constraint in model.assigned_constraints: + constraints_info.append(constraint) + + # Canonical keys. + assigned_cins: list[tuple[str, int] | str] = [] + assigned_couts: list[tuple[str, int] | str] = [] + for conn in model.assigned_cins: + key_info = _extract_key_info(model, submodel_dict, conn) + assigned_cins.append(key_info) + for conn in model.assigned_couts: + key_info = _extract_key_info(model, submodel_dict, conn) + assigned_couts.append(key_info) + # Sort cins and couts in order to get exactly the same json for + # the same model. + assigned_cins.sort(key=lambda x: str(x[-1])) # Simple sort criteria. + assigned_couts.sort(key=lambda x: str(x[-1])) + + return ( + shapes_info, + types_info, + differentiability_info, + constraints_info, + assigned_cins, + assigned_couts, + ) + + +def _extract_connection_from_info_key( + model: BaseModel, + submodel_dict: dict[str, BaseModel], + info_key: tuple[str, int], +) -> ConnectionData: + model_name, index = info_key + return _extract_connection_from_index(model, model_name, index, submodel_dict) + + +def _construct_args_info( + model: BaseModel, + submodel_dict: dict[str, BaseModel], + args: set[ConnectionData | str], + info: tuple[str, int] | str, +) -> None: + if isinstance(info, tuple): + connection = _extract_connection_from_info_key(model, submodel_dict, info) + args.add(connection) + elif isinstance(info, str): + args.add(info) + else: + raise RuntimeError("Unknown info format!") + + +def _construct_config_kwargs_info[T]( + model: BaseModel, + submodel_dict: dict[str, BaseModel], + config: dict[ConnectionData, T], + kwargs: dict[str, T], + info: tuple[tuple[str, int] | str, T], +) -> None: + info_key, data = info + if isinstance(info_key, tuple): + connection = _extract_connection_from_info_key(model, submodel_dict, info_key) + config[connection] = data + elif isinstance(info_key, str): + kwargs[info_key] = data + else: + raise RuntimeError("Unknown info format!") + + +def _extract_connection_from_index( + model: BaseModel, model_name: str, index: int, submodel_dict: dict[str, BaseModel] +) -> ConnectionData: + _model = model if model_name == "self" else submodel_dict[model_name] + connection_key = list(_model.external_keys)[index] + connection = _model.conns.get_connection(connection_key) + assert connection is not None + return connection + + +def _set_assigned_info( + model: BaseModel, + submodel_dict: dict[str, BaseModel], + shapes_info: list[list[tuple[tuple[str, int] | str, ShapeTemplateType]]], + types_info: list[tuple[tuple[str, int] | str, type]], + diffs_info: list[tuple[tuple[str, int] | str, bool]], + constraints_info: list[AssignedConstraintType], + cins_info: list[tuple[str, int] | str], + couts_info: list[tuple[str, int] | str], +) -> None: + # Shapes conversion. + for shp_info in shapes_info: + shapes: dict[ConnectionData, ShapeTemplateType] = {} + shape_kwargs: dict[str, ShapeTemplateType] = {} + for sub_info in shp_info: + _construct_config_kwargs_info( + model, submodel_dict, shapes, shape_kwargs, sub_info + ) + model.set_shapes(shapes, **shape_kwargs) + + # Types conversion. + types_config: dict[ConnectionData, type] = {} + types_kwargs: dict[str, type] = {} + for type_info in types_info: + _construct_config_kwargs_info( + model, submodel_dict, types_config, types_kwargs, type_info + ) + model.set_types(types_config, **types_kwargs) + + # Differentiability settings for models and keys. + diff_config: dict[ConnectionData, bool] = {} + diff_kwargs: dict[str, bool] = {} + for diff_info in diffs_info: + _construct_config_kwargs_info( + model, submodel_dict, diff_config, diff_kwargs, diff_info + ) + model.set_differentiability(diff_config, **diff_kwargs) + + # Constraints. + for constr_info in constraints_info: + constrain_fn = constr_info["fn"] + if constrain_fn not in constrain_fn_dict: + raise RuntimeError( + "In the process of creating a model from a dictionary, an unknown" + " constraint function was encountered!" + ) + constrain_fn = constrain_fn_dict[constrain_fn] + model.add_constraint(constrain_fn, keys=constr_info["keys"]) # type: ignore + + # Canonical keys + cins: set[ConnectionData | str] = set() + couts: set[ConnectionData | str] = set() + for item in cins_info: + _construct_args_info(model, submodel_dict, cins, item) + for item in couts_info: + _construct_args_info(model, submodel_dict, couts, item) + if cins: + model.set_cin(*cins) + if couts: + model.set_cout(*couts) + + def create_iokey_kwargs( info: KeyDict, submodels_dict: dict[str, BaseModel] ) -> dict[str, Any]: @@ -207,27 +465,10 @@ def dict_to_model( } model = type(model_name, (Model,), attrs)(**args) - types: dict[str, str] = params.get("types", {}) - # TODO: Set all types in a bulk. - set_types = {} - for key, typ in types.items(): - if typ == "tensor": - set_types[key] = Tensor[int | float | bool] - # else: - # # TODO: Get rid of using eval method. Find more secure - # # way to convert strings into types and generic types. - # set_types[key] = eval(typ) - unnamed_keys: list[str] = params.get("unnamed_keys", []) - differentiability_info: dict[str, bool] = params.get("differentiability_info", {}) - assigned_shapes = params.get("assigned_shapes", {}) - assigned_constraints = params.get("assigned_constraints", []) - canonical_keys: dict[str, tuple[set[str], set[str]]] = params.get( - "canonical_keys", {} - ) assert isinstance(model, Model) - submodels_dict = {} + submodels_dict: dict[str, BaseModel] = {} for m_key, v in submodels.items(): m = dict_to_model(v) submodels_dict[m_key] = m @@ -243,7 +484,7 @@ def dict_to_model( elif isinstance(conn, dict): if (io_key := conn.get("key")) is not None: # TODO: Update this part according to new IOKey structure. - key_kwargs = create_iokey_kwargs(io_key, submodels_dict) # type: ignore + key_kwargs = create_iokey_kwargs(io_key, submodels_dict) if (conns := key_kwargs.pop("connections", None)) is not None: mappings[k] = conns.pop() mergings[k] = conns @@ -258,33 +499,27 @@ def dict_to_model( con = getattr(m, key) model.merge_connections(con, *conns) - if set_types: - model.set_types(**set_types) - - if "model" in canonical_keys: - if canonical_keys["model"][0] is not None: - model.set_cin(*canonical_keys["model"][0]) - if canonical_keys["model"][1] is not None: - model.set_cout(*canonical_keys["model"][1]) - - for key, value in differentiability_info.items(): - con = model.conns.get_connection(key) - assert con is not None - con.set_differentiability(value) - - if len(assigned_constraints) > 0: - for constr_info in assigned_constraints: - constrain_fn = constr_info["fn"] - if constrain_fn not in constrain_fn_dict: - raise RuntimeError( - "In the process of creating a model from a dictionary, an unknown" - " constraint function was encountered!" - ) - constrain_fn = constrain_fn_dict[constrain_fn] - model.add_constraint(constrain_fn, keys=constr_info["keys"]) # type: ignore + # Set all assigned info. + assigned_types = [ + (info, Tensor if typ == "tensor" else eval(typ)) + for info, typ in params.get("assigned_types", []) + ] + assigned_shapes = params.get("assigned_shapes", []) + assigned_differentiabilities = params.get("assigned_differentiabilities", []) + assigned_constraints = params.get("assigned_constraints", []) + assigned_cins = params.get("assigned_cins", []) + assigned_couts = params.get("assigned_couts", []) + _set_assigned_info( + model, + submodels_dict, + assigned_shapes, + assigned_types, + assigned_differentiabilities, + assigned_constraints, + assigned_cins, + assigned_couts, + ) - if len(assigned_shapes) > 0: - model.set_shapes(**dict_to_shape(assigned_shapes)) assert isinstance(model, Model) return model @@ -295,82 +530,59 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: model_name = model.__class__.__name__ args = handle_model_to_dict_args(model_name, model.factory_args) - assigned_shapes: dict[str, ShapeTemplateType] = {} - differentiability_info: dict[str, bool] = {} - assigned_constraints: list[AssignedConstraintType] = [] - types: dict[str, str] = {} - - for key, con in model.conns.all.items(): - edge = con.metadata - if edge.is_tensor and not con.is_autogenerated: - differentiability_info[key] = edge.differentiable - for shape in model.assigned_shapes: - assigned_shapes |= shape_to_dict(shape) - - for constrain in model.assigned_constraints: - assigned_constraints.append(constrain) + model_dict: ModelDict = { + "name": model_name, + "args": args, + } - for key, typ in model.assigned_types.items(): - if get_origin(typ) is Tensor: - types[key] = "tensor" - # elif typ is not ToBeDetermined: - # types[key] = str(typ) - - if ( - model_name != "Model" - and model_name in dir(models) - or model_name not in dir(models) - ): - model_dict: ModelDict = { - "name": model_name, - "args": args, - "assigned_shapes": assigned_shapes, - "differentiability_info": differentiability_info, - "assigned_constraints": assigned_constraints, - "types": types, - } - return model_dict + submodel_obj_dict: dict[BaseModel, str] = {} - connection_dict: dict[str, dict[str, str | ConnectionDict]] = {} - canonical_keys: dict[str, tuple[set[str], set[str]]] = {} - submodels: dict[str, ModelDict] = {} + if model_name == "Model" and model_name in dir(models): + connection_dict: dict[str, dict[str, str | ConnectionDict]] = {} + submodels: dict[str, ModelDict] = {} - # IOHyperEdge -> [model_id, connection_name] - submodel_connections: dict[IOHyperEdge, list[str]] = {} - assert isinstance(model, Model) + # IOHyperEdge -> [model_id, connection_name] + submodel_connections: dict[IOHyperEdge, list[str]] = {} + assert isinstance(model, Model) - for idx, submodel in enumerate(model.dag.keys()): - model_id = f"m_{idx}" - submodels[model_id] = model_to_dict(submodel) # type: ignore + for idx, submodel in enumerate(model.dag.keys()): + model_id = f"m_{idx}" + submodels[model_id] = model_to_dict(submodel) # type: ignore + submodel_obj_dict[submodel] = model_id - # Store submodel connections - for key in submodel._all_keys: - submodel_connections.setdefault( - submodel.conns.get_metadata(key), [model_id, key] + # Store submodel connections + for key in submodel._all_keys: + submodel_connections.setdefault( + submodel.conns.get_metadata(key), [model_id, key] + ) + assert isinstance(model, Model) + connection_dict[model_id] = connection_to_dict( + model, submodel, submodel_connections, model_id ) - assert isinstance(model, Model) - connection_dict[model_id] = connection_to_dict( - model, submodel, submodel_connections, model_id - ) - canonical_keys[model_id] = ( - get_keys(submodel.conns.cins), - get_keys(submodel.conns.couts), - ) - canonical_keys["model"] = (get_keys(model.conns.cins), get_keys(model.conns.couts)) - composite_model_dict: ModelDict = { - "name": model_name, - "args": args, - "assigned_shapes": assigned_shapes, - "differentiability_info": differentiability_info, - "assigned_constraints": assigned_constraints, - "types": types, - "connections": connection_dict, - "canonical_keys": canonical_keys, - "submodels": submodels, - } - return composite_model_dict + model_dict |= { + "connections": connection_dict, + "submodels": submodels, + } + + ( + assigned_shapes, + assigned_types, + assigned_differentiabilities, + assigned_constraints, + assigned_cins, + assigned_couts, + ) = _serialize_assigned_info(model, submodel_obj_dict) + + model_dict |= {"assigned_shapes": assigned_shapes} + model_dict |= {"assigned_types": assigned_types} + model_dict |= {"assigned_differentiabilities": assigned_differentiabilities} + model_dict |= {"assigned_constraints": assigned_constraints} + model_dict |= {"assigned_cins": assigned_cins} + model_dict |= {"assigned_couts": assigned_couts} + + return model_dict def get_keys(canonicals: set[ConnectionData]) -> set[str]: diff --git a/tests/json_files/integration_directed_test.json b/tests/json_files/integration_directed_test.json index 5232cd00..1235a9b4 100644 --- a/tests/json_files/integration_directed_test.json +++ b/tests/json_files/integration_directed_test.json @@ -2402,10 +2402,9 @@ "args": { "kernel": { "name": "PolynomialKernel", - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_types": [ + ["poly_coef", "tensor"] + ] } } }, @@ -2580,7 +2579,10 @@ "test_distance_matrix_1": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true] + ], "args": { "get_final_distance": false } @@ -2625,7 +2627,10 @@ "test_distance_matrix_2": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true] + ], "args": { "get_final_distance": false } @@ -2667,7 +2672,10 @@ "test_distance_matrix_3": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true] + ], "args": { "get_final_distance": false } @@ -2717,7 +2725,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1.0, 4.0, 9.0], @@ -2756,7 +2767,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1.0, 4.0], @@ -2792,7 +2806,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 3.0, 27.0, 48.0], @@ -2831,7 +2848,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 4.0], @@ -2864,7 +2884,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 4.0, 1.0], @@ -2901,7 +2924,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 25.0], @@ -2935,7 +2961,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1e+32], @@ -2969,7 +2998,10 @@ "name": "MDSCore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1e-32], @@ -3004,7 +3036,10 @@ "args": { "prediction_dim": 3, "input_type": "points" - } + }, + "assigned_differentiabilities": [ + ["coords", true] + ] }, "static_keys": { "input": [[1.0,1.0,0.0,2.0], [2.0,2.0,1.0,3.0]], @@ -3037,7 +3072,10 @@ "name": "TSNECore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 1.0, 9.0], [1.0, 0.0, 4.0], [9.0, 4.0, 0.0]], @@ -3071,7 +3109,10 @@ }, "test_tsne_core_2": { "model": { - "name": "TSNECore" + "name": "TSNECore", + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 16, 1], [16, 0, 9], [1, 9, 0]], @@ -3105,7 +3146,10 @@ }, "test_tsne_core_3": { "model": { - "name": "TSNECore" + "name": "TSNECore", + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "static_keys": { "distances": [[0.0, 16.0], [16.0, 0.0]], @@ -3140,7 +3184,10 @@ "args": { "prediction_dim": 1, "input_type": "points" - } + }, + "assigned_differentiabilities": [ + ["coords", true] + ] }, "static_keys": { "input": [[5.0], [1.0]], @@ -3175,7 +3222,10 @@ "args": { "prediction_dim": 1, "input_type": "points" - } + }, + "assigned_differentiabilities": [ + ["coords", true] + ] }, "static_keys": { "input": [[0.0], [1.0], [3.0]], @@ -3209,11 +3259,11 @@ "test_polynomial_kernel_1": { "model": { "name": "PolynomialKernel", - "differentiability_info": {"input2": true}, - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_differentiabilities": [ + ["input2", true], + ["poly_coef", true], + ["degree", true] + ] }, "inputs": { "input2": [[-1.0, 2.0], [1.0, 3.0]], @@ -3250,11 +3300,11 @@ "test_polynomial_kernel_2": { "model": { "name": "PolynomialKernel", - "differentiability_info": {"input1": true, "input2": true}, - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_differentiabilities": [ + ["input1", true], + ["poly_coef", true], + ["degree", true] + ] }, "inputs": { "input1": [[-1.0, 2.0, 1.0], [1.0, 3.0, 2.0], [-2, 1, -1]], diff --git a/tests/json_files/models_directed_test.json b/tests/json_files/models_directed_test.json index 8c7f054f..539698f0 100644 --- a/tests/json_files/models_directed_test.json +++ b/tests/json_files/models_directed_test.json @@ -2,7 +2,7 @@ "test_linear_1": { "model": { "name": "Linear", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { @@ -27,7 +27,7 @@ "test_linear_2": { "model": { "name": "Linear", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]], @@ -51,7 +51,7 @@ "test_linear_3": { "model": { "name": "Linear", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2], [0.2, 0.3], [0.4, 0.3]], @@ -75,7 +75,7 @@ "test_linear_4": { "model": { "name": "Linear", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2], [0.2, 0.5], [0.3, 0.1]], @@ -99,7 +99,7 @@ "test_layernorm_1": { "model": { "name": "LayerNorm", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2], [0.2, 0.5], [0.3, 0.1]], @@ -127,7 +127,7 @@ "test_layernorm_2": { "model": { "name": "LayerNorm", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2, 3.0], [-5, 0.2, 0.5], [0.3,12.0, 0.1]], @@ -155,7 +155,12 @@ "test_RBFKernel_1": { "model": { "name": "RBFKernel", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["l_scale", true], + ["sigma", true] + ] }, "inputs": { "input1": [[1.0], [2.0], [3.0]], @@ -186,7 +191,12 @@ "description": "We use this setting as input to test_SVMLayerWithReg_2", "model": { "name": "RBFKernel", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["l_scale", true], + ["sigma", true] + ] }, "inputs": { "input1": [[1.0], [2.0], [3.0]], @@ -216,11 +226,12 @@ "test_PolynomialKernel": { "model": { "name": "PolynomialKernel", - "differentiability_info": {"input1": true, "input2": true}, - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["poly_coef", true], + ["degree", true] + ] }, "inputs": { "input1": [[1.0, 2], [2, 3], [3, 4]], @@ -273,7 +284,7 @@ "test_logistic_1_logits_ignored": { "model": { "name": "LogisticRegression", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]], @@ -298,7 +309,7 @@ "test_logistic_2_no_ignore": { "model": { "name": "LogisticRegression", - "differentiability_info": {"input": true} + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]], @@ -324,10 +335,14 @@ "test_distance_matrix_1": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, "args": { "get_final_distance": false - } + }, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[1.0], [2], [3]], @@ -355,7 +370,11 @@ "test_distance_matrix_2": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[1.0], [2], [3]], @@ -383,7 +402,11 @@ "test_distance_matrix_3": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[1.0, 2.0], @@ -419,7 +442,11 @@ "test_distance_matrix_4": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true}, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ], "args": { "get_final_distance": false } @@ -451,7 +478,11 @@ "test_distance_matrix_5": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[1.0], [3], [4]], @@ -480,7 +511,11 @@ "test_distance_matrix_6": { "model": { "name": "EncoderDistanceMatrix", - "differentiability_info": {"input1": true, "input2": true} + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["norm", true] + ] }, "inputs": { "input1": [[3.0, 2], [4, 1], [4, 3]], @@ -515,7 +550,8 @@ "name": "PolynomialFeatures", "args": { "degree": 2 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.0, 3.0]] @@ -537,7 +573,8 @@ "name": "PolynomialFeatures", "args": { "degree": 2 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2], [3], [4]] @@ -559,7 +596,8 @@ "name": "PolynomialFeatures", "args": { "degree": 2 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 1], [2, 3], [3, -1], [0, 10]] @@ -587,7 +625,8 @@ "name": "PolynomialFeatures", "args": { "degree": 3 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2], [2, 1]] @@ -609,12 +648,13 @@ "test_mds_core_1": { "model": { "name": "MDSCore", - "differentiability_info": {"distances": true}, + "assigned_differentiabilities": [ + ["distances", true], + ["pred_distances", true], + ["norm", true] + ], "args": { "exact_distances": false - }, - "types": { - "norm": "tensor" } }, "inputs": { @@ -651,12 +691,13 @@ "test_mds_core_2": { "model": { "name": "MDSCore", - "differentiability_info": {"distances": true}, + "assigned_differentiabilities": [ + ["distances", true], + ["pred_distances", true], + ["norm", true] + ], "args": { "exact_distances": false - }, - "types": { - "norm": "tensor" } }, "inputs": { @@ -689,12 +730,13 @@ "test_mds_core_3": { "model": { "name": "MDSCore", - "differentiability_info": {"distances": true}, + "assigned_differentiabilities": [ + ["distances", true], + ["pred_distances", true], + ["norm", true] + ], "args": { "exact_distances": false - }, - "types": { - "norm": "tensor" } }, "inputs": { @@ -727,12 +769,13 @@ "test_mds_core_4": { "model": { "name": "MDSCore", - "differentiability_info": {"distances": true}, + "assigned_differentiabilities": [ + ["distances", true], + ["pred_distances", true], + ["norm", true] + ], "args": { "exact_distances": false - }, - "types": { - "norm": "tensor" } }, "inputs": { @@ -764,7 +807,11 @@ "NOTE": "Prove the differences of diagonal entries of input gradient wrt composite-ml is because of linearized models like power, log etc.", "model": { "name": "MDS", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [ + ["input", true], + ["coords", true], + ["norm", true] + ], "args": { "prediction_dim": 1, "input_type": "powered_distances" @@ -802,7 +849,9 @@ "test_tsne_core_1": { "model": { "name": "TSNECore", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [ + ["pred_distances", true] + ], "args": { "exact_distances": false } @@ -844,7 +893,10 @@ "name": "TSNECore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "inputs": { "pred_distances": [[0.0, 16, 1], [16, 0, 9], [1, 9, 0]] @@ -886,7 +938,10 @@ "name": "TSNECore", "args": { "exact_distances": false - } + }, + "assigned_differentiabilities": [ + ["pred_distances", true] + ] }, "inputs": { "pred_distances": [[0.0, 16], [16, 0]] @@ -922,7 +977,10 @@ "args": { "prediction_dim": 1, "input_type": "points" - } + }, + "assigned_differentiabilities": [ + ["coords", true] + ] }, "discard_keys": ["predicted_coords"], "inputs": { @@ -950,7 +1008,9 @@ "test_polynomial_1D_Degree_1": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [ + ["input", true] + ], "args": { "degree": 1, "dimension": 1 @@ -978,7 +1038,7 @@ "test_polynomial_2D_Degree_1": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 1, "dimension": 1 @@ -1006,7 +1066,7 @@ "test_polynomial_1D_Degree_2": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 2, "dimension": 1 @@ -1034,7 +1094,7 @@ "test_polynomial_1D_Degree_3": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 3, "dimension": 1 @@ -1062,7 +1122,7 @@ "test_polynomial_2D_Degree_2": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 2, "dimension": 1 @@ -1090,7 +1150,7 @@ "test_polynomial_2D_Degree_2_multilabel": { "model": { "name": "PolynomialRegression", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "degree": 2, "dimension": 1 @@ -2070,7 +2130,7 @@ "test_maxpool1d_1": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 } @@ -2101,7 +2161,7 @@ "test_maxpool1d_2": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 3 } @@ -2130,7 +2190,7 @@ "test_maxpool1d_3": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "stride": 1 @@ -2162,7 +2222,7 @@ "test_maxpool1d_4": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "padding": [0,1] @@ -2196,7 +2256,7 @@ "test_maxpool1d_1_2d_input": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 } @@ -2227,7 +2287,7 @@ "test_maxpool1d_2_2d_input": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 3 } @@ -2256,7 +2316,7 @@ "test_maxpool1d_3_2d_input": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "stride": 1 @@ -2288,7 +2348,7 @@ "test_maxpool1d_4_2d_input": { "model": { "name": "MaxPool1D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "padding": [0,1] @@ -2321,7 +2381,7 @@ "test_maxpool2d_1": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 }, @@ -2363,7 +2423,7 @@ "test_maxpool2d_2": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 } @@ -2492,7 +2552,7 @@ "test_maxpool2d_3": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 3 } @@ -2529,7 +2589,7 @@ "test_maxpool2d_4": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "stride": 1 @@ -2575,7 +2635,7 @@ "test_maxpool2d_5": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "padding": [[1,0],[0,1]] @@ -2621,7 +2681,7 @@ "test_maxpool2d_1_3d_input": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2 }, @@ -2664,7 +2724,7 @@ "test_maxpool2d_2_3d_input": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 3 } @@ -2701,7 +2761,7 @@ "test_maxpool2d_3_3d_input": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "stride": 1 @@ -2747,7 +2807,7 @@ "test_maxpool2d_4_3d_input": { "model": { "name": "MaxPool2D", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "kernel_size": 2, "padding": [[1,0],[0,1]] @@ -2791,7 +2851,8 @@ }, "test_where_1": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": { "cond": [[[ @@ -2840,7 +2901,8 @@ }, "test_where_2": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": { "cond": [[[ @@ -2889,7 +2951,8 @@ }, "test_where_3": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": {"cond": [true]}, "inputs": { @@ -2932,7 +2995,8 @@ }, "test_where_4": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": {"cond": [true]}, "inputs": { @@ -2971,7 +3035,8 @@ }, "test_where_5": { "model": { - "name": "Where" + "name": "Where", + "assigned_differentiabilities": [["input1", true], ["input2", true]] }, "static_keys": {"cond": [false]}, "inputs": { @@ -3011,7 +3076,7 @@ "test_nn_1": { "model": { "name": "MLP", - "differentiability_info": {"input": true}, + "assigned_differentiabilities": [["input", true]], "args": { "dimensions": [2, 2, 2], "activations": ["relu", "relu", "buffer"] @@ -3047,9 +3112,7 @@ "test_buffer_1": { "model": { "name": "Buffer", - "types": { - "input": "tensor" - } + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]] @@ -3071,9 +3134,7 @@ "model": { "name": "Buffer", - "types": { - "input": "tensor" - } + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [2.0, 0.0]] @@ -3095,9 +3156,7 @@ "model": { "name": "Buffer", - "types": { - "input": "tensor" - } + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [2.0, 0.0]] @@ -3119,7 +3178,8 @@ "test_matmul_1": { "model": { - "name": "MatrixMultiply" + "name": "MatrixMultiply", + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0, 3.0, 5.0]], @@ -3142,10 +3202,7 @@ "test_mult_1": { "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0], [2.0], [3.0], [4.0]], @@ -3170,10 +3227,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [2.0], @@ -3198,10 +3252,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3226,10 +3277,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3254,10 +3302,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0], [3.0, 4.0]], @@ -3282,10 +3327,7 @@ "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3308,10 +3350,7 @@ "test_mult_7": { "model": { "name": "Multiply", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0],[2],[3]], @@ -3333,10 +3372,7 @@ "test_div_1": { "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [[1.0], [2.0], [3.0], [4.0]], @@ -3361,10 +3397,7 @@ "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [2.0], @@ -3389,10 +3422,7 @@ "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3417,10 +3447,7 @@ "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [1.0, 2.0], @@ -3445,10 +3472,7 @@ "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [[1.0, 2.0], [3.0, 4.0]], @@ -3471,10 +3495,7 @@ "test_div_6": { "model": { "name": "Divide", - "types": { - "numerator": "tensor", - "denominator": "tensor" - } + "assigned_differentiabilities": [["numerator", true], ["denominator", true]] }, "inputs": { "numerator": [1.0, 2.0], @@ -3497,10 +3518,7 @@ "test_sum_1": { "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0], [2.0], [3.0], [4.0]], @@ -3525,10 +3543,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [2.0], @@ -3553,10 +3568,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3581,10 +3593,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3609,10 +3618,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0], [3.0, 4.0]], @@ -3637,10 +3643,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3664,10 +3667,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[[2.0,3,4]]]]], @@ -3692,10 +3692,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[[[[[[[[[1.0, 1.0],[1.0, 1.0]]]]]]]]]]]], @@ -3720,10 +3717,7 @@ "model": { "name": "Add", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[[[1.0]]]],[[[[1.0]]]]], [[[[[1.0]]]], [[[[1.0]]]]]], @@ -3747,10 +3741,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0], [2.0], [3.0], [4.0]], @@ -3775,10 +3766,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [2.0], @@ -3803,10 +3791,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3831,10 +3816,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3859,10 +3841,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[1.0, 2.0], [3.0, 4.0]], @@ -3887,10 +3866,7 @@ "model": { "name": "Subtract", - "types": { - "left": "tensor", - "right": "tensor" - } + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0], @@ -3915,10 +3891,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[1.0], [2.0], [3.0], [4.0]], @@ -3943,10 +3916,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [2.0], @@ -3971,10 +3941,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], @@ -3999,10 +3966,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [1.0, 2.0], @@ -4027,10 +3991,7 @@ "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[1.0, 2.0], [3.0, 4.0]], @@ -4054,10 +4015,7 @@ "test_power_6": { "model": { "name": "Power", - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [1.0, 2.0], @@ -4080,7 +4038,8 @@ "test_exp": { "model": { - "name": "Exponential" + "name": "Exponential", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [3.0, 4.0]] @@ -4100,7 +4059,8 @@ }, "test_sqrt_1": { "model": { - "name": "Sqrt" + "name": "Sqrt", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [3.0, 4.0]] @@ -4120,7 +4080,8 @@ }, "test_sqrt_2": { "model": { - "name": "Sqrt" + "name": "Sqrt", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[4.0, 16.0], [25.0, 100.0]]]] @@ -4140,7 +4101,8 @@ }, "test_sqrt_3": { "model": { - "name": "Sqrt" + "name": "Sqrt", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [10000.0] @@ -4162,7 +4124,8 @@ "test_abs": { "model": { - "name": "Absolute" + "name": "Absolute", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, 0.0], [1.0, -2.0]] @@ -4187,7 +4150,8 @@ "test_sin": { "model": { - "name": "Sine" + "name": "Sine", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.5235987755982988, 1.0471975511965976], [1.0, 45.0], [90.0, 145.0]] @@ -4207,7 +4171,8 @@ "test_cos": { "model": { - "name": "Cosine" + "name": "Cosine", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.5235987755982988, 1.0471975511965976], [1.0, 45.0], [90.0, 145.0]] @@ -4249,7 +4214,8 @@ "test_flatten": { "model": { - "name": "Flatten" + "name": "Flatten", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0, 3.0, 4.0],[5.0, 6.0, 7.0, 8.0]] @@ -4269,7 +4235,8 @@ "test_flatten2": { "model": { - "name": "Flatten" + "name": "Flatten", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4293,7 +4260,8 @@ "args": { "start_dim": 1, "end_dim": -1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4317,7 +4285,8 @@ "args": { "start_dim": -1, "end_dim": -1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4341,7 +4310,8 @@ "args": { "start_dim": 1, "end_dim": 2 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4363,7 +4333,8 @@ "test_transpose": { "model": { - "name": "Transpose" + "name": "Transpose", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] @@ -4388,7 +4359,8 @@ "args": { "axes": [0,2,1] }, - "tuples": ["axes"] + "tuples": ["axes"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0, 3.0], [2.0, 4.0]], [[1.0, 3.0], [2.0, 4.0]], [[1.0, 3.0], [2.0, 4.0]]] @@ -4414,7 +4386,8 @@ "args": { "axes": [0] }, - "tuples": ["axes"] + "tuples": ["axes"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [3.0] @@ -4439,7 +4412,8 @@ "args": { "axes": [1,4,3,2,0] }, - "tuples": ["axes"] + "tuples": ["axes"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[2.0]]]], [[[[3.0]]]]] @@ -4464,7 +4438,8 @@ "args": { "axes": [2,4,3,0,1] }, - "tuples": ["axes"] + "tuples": ["axes"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[2.0]]]], [[[[3.0]]]]] @@ -4584,7 +4559,8 @@ "test_tanh_1": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[10.0]] @@ -4606,7 +4582,8 @@ "test_tanh_2": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[30.0]] @@ -4628,7 +4605,8 @@ "test_tanh_3": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.220446049250313e-16]] @@ -4650,7 +4628,8 @@ "test_tanh_4": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4672,7 +4651,8 @@ "test_tanh_5": { "model": { - "name": "Tanh" + "name": "Tanh", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4695,7 +4675,8 @@ "test_sigmoid_1": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[20.0]] @@ -4717,7 +4698,8 @@ "test_sigmoid_2": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4739,7 +4721,8 @@ "test_sigmoid_3": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-30.0]] @@ -4761,7 +4744,8 @@ "test_sigmoid_4": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[919.78546867]] @@ -4783,7 +4767,8 @@ "test_sigmoid_5": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-919.78546867]] @@ -4805,7 +4790,8 @@ "test_sigmoid_6": { "model": { - "name": "Sigmoid" + "name": "Sigmoid", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4827,7 +4813,8 @@ "test_softplus_1": { "model": { - "name": "Softplus" + "name": "Softplus", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.0]] @@ -4849,7 +4836,8 @@ "test_softplus_2": { "model": { - "name": "Softplus" + "name": "Softplus", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0, 2.0, -30]] @@ -4871,7 +4859,8 @@ "test_softplus_3": { "model": { - "name": "Softplus" + "name": "Softplus", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4893,7 +4882,8 @@ "test_softplus_4": { "model": { - "name": "Softplus" + "name": "Softplus", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -4914,7 +4904,8 @@ "test_permute_tensor_1": { "model": { - "name": "PermuteTensor" + "name": "PermuteTensor", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [ @@ -4952,7 +4943,8 @@ }, "test_permute_tensor_2": { "model": { - "name": "PermuteTensor" + "name": "PermuteTensor", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [ @@ -4990,7 +4982,11 @@ }, "test_squared_error_1": { "model": { - "name": "SquaredError" + "name": "SquaredError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0], [2], [3], [4]], @@ -5013,7 +5009,11 @@ "test_squared_error_2": { "model": { - "name": "SquaredError" + "name": "SquaredError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0], [-1000000.0], [0.000000000001], [-1.5]], @@ -5035,7 +5035,11 @@ }, "test_squared_error_3": { "model": { - "name": "SquaredError" + "name": "SquaredError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[0.1, 0.2], [1000000000.0, 0.0000000001]], @@ -5057,7 +5061,11 @@ }, "test_squared_error_4": { "model": { - "name": "SquaredError" + "name": "SquaredError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0, 2], [3, 4], [5, 6]], @@ -5079,7 +5087,8 @@ }, "test_hinge_loss_1": { "model": { - "name": "HingeLoss" + "name": "HingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.0], [0.25]] @@ -5104,7 +5113,8 @@ "test_hinge_loss_2": { "model": { - "name": "HingeLoss" + "name": "HingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [-1.0]] @@ -5129,7 +5139,8 @@ }, "test_hinge_loss_3": { "model": { - "name": "HingeLoss" + "name": "HingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0], [0.0]] @@ -5151,7 +5162,11 @@ }, "test_absolute_error_1": { "model": { - "name": "AbsoluteError" + "name": "AbsoluteError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0]], @@ -5180,7 +5195,11 @@ }, "test_absolute_error_2": { "model": { - "name": "AbsoluteError" + "name": "AbsoluteError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0], [-1000000.0], [0.000000000001], [-1.5]], @@ -5210,7 +5229,11 @@ "test_absolute_error_3": { "model": { - "name": "AbsoluteError" + "name": "AbsoluteError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[0.1, 0.2], [1000000000.0, 0.0000000001]], @@ -5240,7 +5263,11 @@ "test_absolute_error_4": { "model": { - "name": "AbsoluteError" + "name": "AbsoluteError", + "assigned_differentiabilities": [ + ["input", true], + ["target", true] + ] }, "inputs": { "input": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], @@ -5270,7 +5297,8 @@ "test_cross_entropy_1": { "model": { - "name": "CrossEntropy" + "name": "CrossEntropy", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 1.0, 1.0, 1.0]] @@ -5293,7 +5321,8 @@ "test_cross_entropy_2": { "model": { - "name": "CrossEntropy" + "name": "CrossEntropy", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1000.0, 0.0], [0.0, 1000.0]] @@ -5316,7 +5345,8 @@ "test_cross_entropy_3": { "model": { - "name": "CrossEntropy" + "name": "CrossEntropy", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 1.0], [0.0, 1000.0]] @@ -5341,7 +5371,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] @@ -5368,7 +5399,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] @@ -5399,7 +5431,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[2.220446049250313e-16, 1.0], [0.1, 0.9]] @@ -5426,7 +5459,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0, 1.0]] @@ -5457,7 +5491,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.2, 1.1102230246251565e-16, 0.7999999999999998], [0.1, 0.6, 0.3]] @@ -5483,7 +5518,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.5, 0.5], [0.1, 0.9]] @@ -5510,7 +5546,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.5, 0.5], [0.1, 0.9]] @@ -5537,7 +5574,8 @@ "name": "CrossEntropy", "args": { "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0, 1.0], [0.1, 0.9]] @@ -5563,7 +5601,8 @@ "name": "CrossEntropy", "args": { "input_type": "log_probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[-0.6931471805599453, -0.6931471805599453], [-2.3025850929940455, -0.10536051565782628]] @@ -5589,7 +5628,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.1], [0.5]] @@ -5615,7 +5655,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.1, 0.2, 0.3], [0.5, 0.4, 0.2]] @@ -5645,7 +5686,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1e-20, 0.2, 0.3], [0.5, 0.4, 1e-20]] @@ -5675,7 +5717,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.2, 0.3], [0.5, 1.0, 0.2]] @@ -5706,7 +5749,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.1], [0.5]] @@ -5732,7 +5776,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.1102230246251565e-16], [0.5]] @@ -5758,7 +5803,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0102230246251565e-16], [0.5]] @@ -5784,7 +5830,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "logits" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [ @@ -5825,7 +5872,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "probs" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.1102230246251565e-16], [0.95]] @@ -5851,7 +5899,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "logits" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-36.83119036496738], [0.0]] @@ -5881,7 +5930,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "logits" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.0], [-2]] @@ -5907,7 +5957,8 @@ "name": "BinaryCrossEntropy", "args":{ "input_type": "logits" - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.0, -2, 3, 0], [0, 1, 2, -1]] @@ -5939,7 +5990,8 @@ }, "test_quantile_loss_1": { "model": { - "name": "QuantileLoss" + "name": "QuantileLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0]] @@ -5962,7 +6014,8 @@ }, "test_quantile_loss_2": { "model": { - "name": "QuantileLoss" + "name": "QuantileLoss", + "assigned_differentiabilities": [["input", true]] }, "static_keys": { "quantile": 0.1, @@ -5985,7 +6038,8 @@ }, "test_quantile_loss_3": { "model": { - "name": "QuantileLoss" + "name": "QuantileLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [1e-5]] @@ -6008,7 +6062,8 @@ }, "test_quantile_loss_4": { "model": { - "name": "QuantileLoss" + "name": "QuantileLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.0, 2.0], [0.0, 1.0]] @@ -6027,7 +6082,8 @@ }, "test_quad_hinge_loss_1": { "model": { - "name": "QuadHingeLoss" + "name": "QuadHingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[2.0], [0.25]] @@ -6049,7 +6105,8 @@ }, "test_quad_hinge_loss_2": { "model": { - "name": "QuadHingeLoss" + "name": "QuadHingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[1.0], [-1.0]] @@ -6072,7 +6129,8 @@ "test_quad_hinge_loss_3": { "model": { - "name": "QuadHingeLoss" + "name": "QuadHingeLoss", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input":[[0.0], [0.0]] @@ -6094,7 +6152,8 @@ }, "test_kl_div_1": { "model": { - "name": "KLDivergence" + "name": "KLDivergence", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.5], [0.5, 0.1]] @@ -6116,7 +6175,8 @@ }, "test_kl_div_3": { "model": { - "name": "KLDivergence" + "name": "KLDivergence", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.5], [0.5, 0.1]] @@ -6138,7 +6198,8 @@ }, "test_kl_div_4": { "model": { - "name": "KLDivergence" + "name": "KLDivergence", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.1102230246251565e-16, 0.5], [0.5, 0.1]] @@ -6160,7 +6221,8 @@ }, "test_relu_1": { "model": { - "name": "Relu" + "name": "Relu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6179,7 +6241,8 @@ }, "test_relu_2": { "model": { - "name": "Relu" + "name": "Relu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0]] @@ -6198,7 +6261,8 @@ }, "test_relu_3": { "model": { - "name": "Relu" + "name": "Relu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] @@ -6217,7 +6281,8 @@ }, "test_relu_4": { "model": { - "name": "Relu" + "name": "Relu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] @@ -6236,7 +6301,8 @@ }, "test_leaky_relu_1": { "model": { - "name": "LeakyRelu" + "name": "LeakyRelu", + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"slope": 0.2}, "inputs": { @@ -6260,7 +6326,8 @@ }, "test_leaky_relu_2": { "model": { - "name": "LeakyRelu" + "name": "LeakyRelu", + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"slope": 0.2}, "inputs": { @@ -6284,7 +6351,8 @@ }, "test_leaky_relu_3": { "model": { - "name": "LeakyRelu" + "name": "LeakyRelu", + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"slope": 0.2}, "inputs": { @@ -6308,7 +6376,8 @@ }, "test_gelu_1": { "model": { - "name": "Gelu" + "name": "Gelu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6327,7 +6396,8 @@ }, "test_gelu_2": { "model": { - "name": "Gelu" + "name": "Gelu", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] @@ -6346,7 +6416,8 @@ }, "test_stop_gradient_1": { "model": { - "name": "StopGradient" + "name": "StopGradient", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6365,7 +6436,8 @@ }, "test_stop_gradient_2": { "model": { - "name": "StopGradient" + "name": "StopGradient", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [-0.04, -100]] @@ -6384,7 +6456,8 @@ }, "test_softmax_1": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0, 2.0, -30]] @@ -6403,7 +6476,8 @@ }, "test_softmax_2": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.0, 0.0]] @@ -6422,7 +6496,8 @@ }, "test_softmax_3": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6441,7 +6516,8 @@ }, "test_softmax_4": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0, 2.0, -30]] @@ -6460,7 +6536,8 @@ }, "test_softmax_5": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.1, 0.0, 0.0]] @@ -6479,7 +6556,8 @@ }, "test_softmax_6": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6498,7 +6576,8 @@ }, "test_softmax_7": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, -2.0], [2.0, 0.0]] @@ -6517,7 +6596,8 @@ }, "test_softmax_8": { "model": { - "name": "Softmax" + "name": "Softmax", + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"axis": 0}, "inputs": { @@ -6540,7 +6620,8 @@ "name": "Log", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[2.0, 2.0], [3.0, 4.0], [4.0, 100.0]] @@ -6562,7 +6643,8 @@ "name": "Log", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[0.0]]] @@ -6588,7 +6670,8 @@ "name": "Log", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1e-311, 1e-306]]] @@ -6611,7 +6694,8 @@ }, "test_stable_reciprocal_1": { "model": { - "name": "StableReciprocal" + "name": "StableReciprocal", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[2.0, 2.0], [3.0, 4.0], [4.0, 100.0]]]]] @@ -6630,7 +6714,8 @@ }, "test_stable_reciprocal_2": { "model": { - "name": "StableReciprocal" + "name": "StableReciprocal", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0]] @@ -6649,7 +6734,8 @@ }, "test_stable_reciprocal_3": { "model": { - "name": "StableReciprocal" + "name": "StableReciprocal", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1e-155, 1e-145]] @@ -6671,7 +6757,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0, 2.0], [3.0, 4.0]] @@ -6693,7 +6780,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[4.0, 16.0], [25.0, 100.0]]]] @@ -6715,7 +6803,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [10000.0] @@ -6737,7 +6826,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[0.0, -4.0], [-1.0, -4.0]] @@ -6763,7 +6853,8 @@ "name": "Sqrt", "args": { "robust": true - } + }, + "assigned_differentiabilities": [["input", true]] }, "static_keys": {"cutoff": 1e-20}, "inputs": { @@ -6791,10 +6882,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[0.0], [2.2250738585072014e-308]], @@ -6819,10 +6907,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[2.0], [3.0], [4.0]], @@ -6851,10 +6936,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [2.0], @@ -6885,10 +6967,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [2.0], @@ -6913,10 +6992,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [1.0, 2.0], @@ -6942,10 +7018,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [1.0, 2.0], @@ -6971,10 +7044,7 @@ "args": { "robust": true }, - "types": { - "base": "tensor", - "exponent": "tensor" - } + "assigned_differentiabilities": [["base", true], ["exponent", true]] }, "inputs": { "base": [[1e-311, 1e-311]], @@ -7208,7 +7278,8 @@ }, "test_var_1": { "model": { - "name": "Variance" + "name": "Variance", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0], [5.0]] @@ -7231,7 +7302,8 @@ "args": { "correction": 0.0, "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0], [5.0]] @@ -7254,7 +7326,8 @@ "args": { "correction": 0.0, "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0], [5.0]] @@ -7276,7 +7349,8 @@ "name": "Variance", "args": { "correction": 1.0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[1.0], [2.0], [3.0], [4.0], [5.0]] @@ -7295,7 +7369,8 @@ }, "test_reduce_sum_1": { "model": { - "name": "Sum" + "name": "Sum", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[[[[[[[[1.0,2.0,3.0],[8.0,6.0,1.0]],[[7.0,6.0,1.0],[11.0,9.0,3.0]]]]]]]]]]]]]]] @@ -7314,7 +7389,8 @@ }, "test_reduce_sum_2": { "model": { - "name": "Sum" + "name": "Sum", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-7.0, -8.0], [6.0, 3.0], [4.0, 5.0]] @@ -7336,7 +7412,8 @@ "name": "Sum", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7358,7 +7435,8 @@ "name": "Sum", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7380,7 +7458,8 @@ "name": "Sum", "args": { "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7403,7 +7482,8 @@ "args": { "axis": [0,2] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]] @@ -7426,7 +7506,8 @@ "args": { "axis": [0,1,2,3,4] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]]]]]]] @@ -7445,7 +7526,8 @@ }, "test_reduce_mean_1": { "model": { - "name": "Mean" + "name": "Mean", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[[[[[[[[1.0,2.0,3.0],[8.0,6.0,1.0]],[[7.0,6.0,1.0],[11.0,9.0,3.0]]]]]]]]]]]]]]] @@ -7464,7 +7546,8 @@ }, "test_reduce_mean_2": { "model": { - "name": "Mean" + "name": "Mean", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-7.0, -8.0], [6.0, 3.0], [4.0, 5.0]] @@ -7486,7 +7569,8 @@ "name": "Mean", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7508,7 +7592,8 @@ "name": "Mean", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7530,7 +7615,8 @@ "name": "Mean", "args": { "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7553,7 +7639,8 @@ "args": { "axis": [0,2] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]] @@ -7576,7 +7663,8 @@ "args": { "axis": [0,1,2,3,4] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]]]]]]] @@ -7595,7 +7683,8 @@ }, "test_reduce_max_1": { "model": { - "name": "Max" + "name": "Max", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[[[[[[[[1.0,2.0,3.0],[8.0,6.0,1.0]],[[7.0,6.0,1.0],[11.0,9.0,3.0]]]]]]]]]]]]]]] @@ -7614,7 +7703,8 @@ }, "test_reduce_max_2": { "model": { - "name": "Max" + "name": "Max", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-7.0, -8.0], [6.0, 6.0], [6.0, 5.0]] @@ -7636,7 +7726,8 @@ "name": "Max", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7658,7 +7749,8 @@ "name": "Max", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7680,7 +7772,8 @@ "name": "Max", "args": { "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7703,7 +7796,8 @@ "args": { "axis": [0,2] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]] @@ -7726,7 +7820,8 @@ "args": { "axis": [0,1,2,3,4] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]]]]]]] @@ -7745,7 +7840,8 @@ }, "test_reduce_min_1": { "model": { - "name": "Min" + "name": "Min", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[[[[[[[[1.0,2.0,3.0],[8.0,6.0,1.0]],[[7.0,6.0,1.0],[11.0,9.0,3.0]]]]]]]]]]]]]]] @@ -7764,7 +7860,8 @@ }, "test_reduce_min_2": { "model": { - "name": "Min" + "name": "Min", + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-7.0, -8.0], [6.0, 6.0], [6.0, 5.0]] @@ -7786,7 +7883,8 @@ "name": "Min", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7808,7 +7906,8 @@ "name": "Min", "args": { "axis": 1 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7830,7 +7929,8 @@ "name": "Min", "args": { "axis": 0 - } + }, + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[-1.0, -2.0], [2.0, 0.0], [1.0, -1.0]] @@ -7853,7 +7953,8 @@ "args": { "axis": [0,2] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]] @@ -7876,7 +7977,8 @@ "args": { "axis": [0,1,3,5,6] }, - "tuples": ["axis"] + "tuples": ["axis"], + "assigned_differentiabilities": [["input", true]] }, "inputs": { "input": [[[[[[[[[[1.0,2.0,3.0],[4.0,5.0,6.0]],[[3.0,2.0,1.0],[1.0,3.0,5.0]],[[6.0,2.0,5.0],[9.0,3.0,1.0]]]]]]]]]] @@ -7921,7 +8023,8 @@ "right": {"key": {"connect": [["m2","output"]]}}, "output": "output" } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[1.0, 2.0, 3.0]]]], @@ -8058,7 +8161,8 @@ "expose": true }} } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[1.0, 2.0, 3.0]]]], @@ -8269,7 +8373,8 @@ "expose": true }} } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [1.0, 2.0, 3.0], @@ -8451,7 +8556,8 @@ "expose": true }} } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[1.0, 2.0, 3.0]]]], @@ -8576,7 +8682,8 @@ "right": {"key": {"name": "right", "type": "tensor"}}, "output": {"key": {"connect": [["m2","right"]]}} } - } + }, + "assigned_differentiabilities": [["left", true], ["right", true]] }, "inputs": { "left": [[[[1.0, 2.0, 3.0]]]], @@ -8632,7 +8739,8 @@ } - } + }, + "assigned_differentiabilities": [["input1", true]] }, "inputs": { "input1": [[[[1.0, 2.0, 3.0]]]] @@ -8683,7 +8791,8 @@ "m1": { "input": {"key": {"connect": [["m2","input"]]}} } - } + }, + "assigned_differentiabilities": [["my_input", true]] }, "inputs": { "my_input": [[[[1.0, 2.0, 3.0]]]] @@ -8734,7 +8843,8 @@ "m1": { "input": {"key": {"connect": [["m2","input"]]}} } - } + }, + "assigned_differentiabilities": [["my_input", true]] }, "inputs": { "my_input": [[[[1.0, 2.0, 3.0]]]] @@ -8877,7 +8987,10 @@ "test_rnn_cell_1": { "model": { "name": "RNNCell", - "differentiability_info": {"input": true, "prev_hidden": true} + "assigned_differentiabilities": [ + ["input", true], + ["prev_hidden", true] + ] }, "inputs": { @@ -8916,7 +9029,10 @@ "test_rnn_cell_2": { "model": { "name": "RNNCell", - "differentiability_info": {"input": true, "prev_hidden": true} + "assigned_differentiabilities": [ + ["input", true], + ["prev_hidden", true] + ] }, "inputs": { @@ -8955,7 +9071,11 @@ "test_lstm_cell_1": { "model": { "name": "LSTMCell", - "differentiability_info": {"input": true, "prev_hidden": true, "prev_cell": true} + "assigned_differentiabilities": [ + ["input", true], + ["prev_hidden", true], + ["prev_cell", true] + ] }, "inputs": { diff --git a/tests/json_files/no_grad_inference_test.json b/tests/json_files/no_grad_inference_test.json index c99b9e1e..474fb0c9 100644 --- a/tests/json_files/no_grad_inference_test.json +++ b/tests/json_files/no_grad_inference_test.json @@ -19,6 +19,11 @@ "name": "Multiply" } }, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["input3", true] + ], "connections": { "m1": { "left": {"key": {"name": "input1", "type": "tensor"}}, @@ -96,6 +101,11 @@ "name": "Multiply" } }, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["input3", true] + ], "connections": { "m1": { "left": {"key": {"name": "input1", "type": "tensor"}}, @@ -172,6 +182,11 @@ "name": "Multiply" } }, + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true], + ["input3", true] + ], "connections": { "m1": { "left": {"key": {"name": "input1", "type": "tensor"}}, diff --git a/tests/json_files/randomized_model_tests_all_backends.json b/tests/json_files/randomized_model_tests_all_backends.json index 9a99b045..8a77b59f 100644 --- a/tests/json_files/randomized_model_tests_all_backends.json +++ b/tests/json_files/randomized_model_tests_all_backends.json @@ -1647,7 +1647,10 @@ "test_one_to_many_rnn": { "model": { "name": "OneToMany", - "differentiability_info": {"input": true, "initial_hidden": true}, + "assigned_differentiabilities": [ + ["input", true], + ["initial_hidden", true] + ], "regular_args": { "max_sequence_length": 4, "cell_type": "RNNCell" @@ -2086,11 +2089,13 @@ "test_polynomial_kernel": { "model": { "name": "PolynomialKernel", - "differentiability_info": {"input1": true, "input2": true}, - "types": { - "poly_coef": "tensor", - "degree": "tensor" - } + "assigned_differentiabilities": [ + ["input1", true], + ["input2", true] + ], + "assigned_types": [ + ["poly_coef", "tensor"] + ] }, "static_input_info": { "poly_coef": { diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 7f8cfa16..19809f97 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -17,7 +17,11 @@ from mithril import Backend, Constant, compile, epsilon_table from mithril.framework.common import IOHyperEdge, Tensor from mithril.models import BaseModel, Model, Operator, TrainModel -from mithril.utils.dict_conversions import dict_to_model, model_to_dict +from mithril.utils.dict_conversions import ( + dict_to_model, + extract_model_key_index, + model_to_dict, +) from tests.scripts.test_utils import ( assert_all_conn_key_are_same, convert_to_array, @@ -59,17 +63,15 @@ def evaluate_case( else: is_list = static_keys.get("is_list", False) static_keys[key] = convert_to_array(backend, value, is_list) + # Set logical tensor values for trainable keys if and only if - # model has any output keys. - if model.output_keys: + # model has list type inputs and has any output keys. + if is_inputs_list and model.output_keys: values: dict[str, Tensor[float] | list[Tensor[float]]] = {} for key, value in reference_gradients.items(): - if is_inputs_list: - values[key] = [ - Tensor[float](differentiable=True) for _ in range(len(value)) - ] - else: - values[key] = Tensor[float](differentiable=True) + values[key] = [ + Tensor[float](differentiable=True) for _ in range(len(value)) + ] model.set_values(**values) models.append(model) @@ -212,19 +214,44 @@ def assert_models_equal(model1: BaseModel, model2: BaseModel): model1_keys = model1.generate_keys() model2_keys = model2.generate_keys() - # if model1.cin is not None and model2.cin is not None: - # assert model1_keys.get( - # key := model1.cin.key, key - # ) == model2_keys.get(key := model2.cin.key, key) - # assert model1_keys.get( - # key := model1.cout.key, key - # ) == model2_keys.get(key := model2.cout.key, key) - assert {model1_keys.get(con.key, con.key) for con in model2.conns.cins} == { - model2_keys.get(con.key, con.key) for con in model2.conns.cins - } - assert {model1_keys.get(con.key, con.key) for con in model2.conns.couts} == { - model2_keys.get(con.key, con.key) for con in model2.conns.couts - } + # For exact match, we need to compare the keys together with + # model and key_index of corresponding key for the corresponding model. + # Model info is converted also to index, which is the index of the model + # in the DAG, if extracted model is not the model itself. If the extracted + # model is the model itself, then the index is "self". + + # Canonical input keys tests. + model1_cins = set() + model2_cins = set() + for conn in model1.conns.cins: + model, key_index = extract_model_key_index(model1, conn) + model_index = ( + list(model1.dag.keys()).index(model) if model != model1 else "self" + ) + model1_cins.add((model_index, key_index)) + for conn in model2.conns.cins: + model, key_index = extract_model_key_index(model2, conn) + model_index = ( + list(model2.dag.keys()).index(model) if model != model2 else "self" + ) + model2_cins.add((model_index, key_index)) + assert model1_cins == model2_cins + # Canonical output keys tests. + model1_couts = set() + model2_couts = set() + for conn in model1.conns.couts: + model, key_index = extract_model_key_index(model1, conn) + model_index = ( + list(model1.dag.keys()).index(model) if model != model1 else "self" + ) + model1_couts.add((model_index, key_index)) + for conn in model2.conns.couts: + model, key_index = extract_model_key_index(model2, conn) + model_index = ( + list(model2.dag.keys()).index(model) if model != model2 else "self" + ) + model2_couts.add((model_index, key_index)) + assert model1_couts == model2_couts # NOTE: Below assertions will be uncommented after converting # model's dag from topological order to insertion order. diff --git a/tests/scripts/summary_txts/test_physical_summary_17 b/tests/scripts/summary_txts/test_physical_summary_17 index 78c25e8e..873e0802 100644 --- a/tests/scripts/summary_txts/test_physical_summary_17 +++ b/tests/scripts/summary_txts/test_physical_summary_17 @@ -18,16 +18,16 @@ Total Parameters : >0 ------------------------------- - MatrixMultiply ----------------------------------------------------------------------------------------------------------- -Model Name | Model Keys - | --------------------------------------------------------------------------- - | Keys : Shapes : Types : Connections : Parameters -========================================================================================================== -MatrixMultiply | Inputs : left : [None, ..., None] : bool | float | int : 'input' : Unknown - | right : [None, 3] : float : 'output_0' : 0 - | ------------------------------------------------------------------------------------- - | Outputs : output : [None, ..., 3] : float : 'output_1' : 0 ----------------------------------------------------------------------------------------------------------- + MatrixMultiply +--------------------------------------------------------------------------------------------- +Model Name | Model Keys + | -------------------------------------------------------------- + | Keys : Shapes : Types : Connections : Parameters +============================================================================================= +MatrixMultiply | Inputs : left : [None, ..., None] : float : 'input' : Unknown + | right : [None, 3] : float : 'output_0' : 0 + | ------------------------------------------------------------------------ + | Outputs : output : [None, ..., 3] : float : 'output_1' : 0 +--------------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth b/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth index b5a14a3c..4fabd9ca 100644 --- a/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth +++ b/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth @@ -4,8 +4,8 @@ Model Name | Model Keys | -------------------------------------------------------------- | Keys : Shapes : Types : Connections ============================================================================================ -KernelizedSVM | Inputs : input1 : [u1, u2] : bool | float | int : '$input' - | input2 : [u3, u2] : bool | float | int : '$input2' +KernelizedSVM | Inputs : input1 : [u1, u2] : float : '$input' + | input2 : [u3, u2] : float : '$input2' | sigma : [ 1] : bool | float | int : '$sigma' | l_scale : [ 1] : bool | float | int : '$l_scale' | weight : [ 1, u3] : float : '$weight' @@ -34,13 +34,13 @@ Model Name | Model Keys | ---------------------------------------------------------- | Keys : Shapes : Types : Connections ===================================================================================== -RBFKernel | Inputs : input1 : [u1, u2] : bool | float | int : 'input1' - | input2 : [u3, u2] : bool | float | int : 'input2' +RBFKernel | Inputs : input1 : [u1, u2] : float : 'input1' + | input2 : [u3, u2] : float : 'input2' | $right : -- : float : -0.5 | sigma : [ 1] : bool | float | int : 'sigma' | l_scale : [ 1] : bool | float | int : 'l_scale' | -------------------------------------------------------------------- - | Outputs : output : [u1, u3] : bool | float | int : Linear.input + | Outputs : output : [u1, u3] : float : Linear.input | 'kernel' ------------------------------------------------------------------------------------- Linear | Inputs : weight : [ 1, u3] : float : 'weight' diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index f61c6e6f..0f66c0e8 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -145,8 +145,7 @@ def compile_and_compare( pm = mithril.compile( model, backend=backend, - **compile_kwargs - | {"constant_keys": statics, "trainable_keys": params.keys()}, + **compile_kwargs | {"constant_keys": statics}, ) outputs = pm.evaluate(params=backend_params, data=backend_data) @@ -268,7 +267,7 @@ def test_buffer_1(): def test_buffer_2(): model = Buffer() - model.set_types(input=Tensor) + model.set_differentiability(input=True) params = {"input": [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]} output_gradients = {"output": [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]]} reference_outputs = {"output": [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]} @@ -396,7 +395,7 @@ def test_isnan_2(): def test_nan_to_num_1(): - model = NanToNum() + model = NanToNum(input=Tensor(differentiable=True)) params = {"input": [[1.0, float("nan"), 3.0], [1.0, 2.0, float("nan")]]} output_gradients = {"output": [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]} reference_outputs = {"output": [[1.0, 0.0, 3.0], [1.0, 2.0, 0.0]]} @@ -669,7 +668,9 @@ def test_conv1d_6(): def test_where_1(): - model = Where() + model = Where( + input1=Tensor(differentiable=True), input2=Tensor(differentiable=True) + ) data = { "cond": [True, True, False, False, True], @@ -700,7 +701,9 @@ def test_where_1(): def test_where_2(): - model = Where() + model = Where( + input1=Tensor(differentiable=True), input2=Tensor(differentiable=True) + ) data = { "cond": [True, True, False, False, True], @@ -1243,7 +1246,7 @@ def test_to_tensor(): def test_reduce_prod_1(): - model = Prod(axis=None) + model = Prod(input=Tensor(differentiable=True), axis=None) params = {"input": [1.0, 2.0, 3.0, 4.0, 5.0]} output_gradients = {"output": 1.0} reference_outputs = {"output": 120.0} @@ -1263,7 +1266,7 @@ def test_reduce_prod_1(): def test_reduce_prod_2(): - model = Prod(axis=None) + model = Prod(input=Tensor(differentiable=True), axis=None) params = {"input": [1.0, 0.0, 3.0, 4.0, 5.0]} output_gradients = {"output": 1.0} reference_outputs = {"output": 0.0} @@ -1283,7 +1286,7 @@ def test_reduce_prod_2(): def test_reduce_prod_3(): - model = Prod(axis=None) + model = Prod(input=Tensor(differentiable=True), axis=None) params = {"input": [1.0, 0.0, 3.0, 0.0, 5.0]} output_gradients = {"output": 12.0} reference_outputs = {"output": 0.0} @@ -1303,7 +1306,7 @@ def test_reduce_prod_3(): def test_reduce_prod_4(): - model = Prod(axis=1) + model = Prod(input=Tensor(differentiable=True), axis=1) params = {"input": [[1.0, 2.0], [3.0, 4.0]]} output_gradients = {"output": [2.0, 3.0]} reference_outputs = {"output": [2.0, 12.0]} @@ -1323,7 +1326,7 @@ def test_reduce_prod_4(): def test_reduce_prod_5(): - model = Prod(axis=(1, 2)) + model = Prod(input=Tensor(differentiable=True), axis=(1, 2)) params = { "input": [ [[1.0, 2.0], [3.0, 4.0]], @@ -1359,7 +1362,7 @@ def test_reduce_prod_5(): def test_reduce_prod_6(): - model = Prod(axis=(1, 2), keepdim=True) + model = Prod(input=Tensor(differentiable=True), axis=(1, 2), keepdim=True) params = { "input": [ [[1.0, 2.0], [3.0, 4.0]], @@ -1643,7 +1646,7 @@ def test_eye_complement_w_dtype(): def test_squeeze_1(): - model = Squeeze() + model = Squeeze(input=Tensor(differentiable=True)) params = {"input": list_full(1.0, 3, 1, 4, 2, 1)} output_gradients = {"output": list_full(1.0, 3, 4, 2)} reference_outputs = {"output": list_full(1.0, 3, 4, 2)} @@ -1663,7 +1666,7 @@ def test_squeeze_1(): def test_squeeze_2(): - model = Squeeze() + model = Squeeze(input=Tensor(differentiable=True)) params = {"input": list_full(1.0, 3, 1, 4, 2, 1, 1, 1, 5)} output_gradients = {"output": list_full(1.0, 3, 4, 2, 5)} reference_outputs = {"output": list_full(1.0, 3, 4, 2, 5)} @@ -1683,7 +1686,7 @@ def test_squeeze_2(): def test_broadcast_to_1(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": list_full(1.0, 1, 1)} output_gradients = {"output": list_full(1.0, 3, 3)} reference_outputs = {"output": list_full(1.0, 3, 3)} @@ -1704,7 +1707,7 @@ def test_broadcast_to_1(): def test_broadcast_to_2(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": [4.0]} output_gradients = {"output": [[3.0, 4.0], [5.0, 6.0]]} reference_outputs = {"output": [[4.0, 4.0], [4.0, 4.0]]} @@ -1725,7 +1728,7 @@ def test_broadcast_to_2(): def test_broadcast_to_3(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": [[1.0], [7.0]]} output_gradients = {"output": [[3.0, 4.0], [5.0, 6.0]]} reference_outputs = {"output": [[1.0, 1.0], [7.0, 7.0]]} @@ -1746,7 +1749,7 @@ def test_broadcast_to_3(): def test_broadcast_to_4(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": [1.0]} output_gradients = {"output": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]} reference_outputs = {"output": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]} @@ -1767,7 +1770,7 @@ def test_broadcast_to_4(): def test_broadcast_to_5(): - model = BroadcastTo() + model = BroadcastTo(input=Tensor(differentiable=True)) params = {"input": [[1.0, 2.0], [3.0, 4.0]]} output_gradients = { "output": [ @@ -1804,7 +1807,7 @@ def test_broadcast_to_5(): def test_norm_modifier_1(): - model = NormModifier() + model = NormModifier(input=Tensor(differentiable=True)) params = {"input": 3.0} output_gradients = {"output": 2.0} reference_outputs = {"output": 3.0} @@ -1824,7 +1827,7 @@ def test_norm_modifier_1(): def test_norm_modifier_2(): - model = NormModifier() + model = NormModifier(input=Tensor(differentiable=True)) params = {"input": 6.0} output_gradients = {"output": 2.0} reference_outputs = {"output": 4.0} @@ -1844,7 +1847,7 @@ def test_norm_modifier_2(): def test_norm_modifier_3(): - model = NormModifier() + model = NormModifier(input=Tensor(differentiable=True)) params = {"input": -1.0} output_gradients = {"output": 2.0} reference_outputs = {"output": 3.0} @@ -1917,6 +1920,7 @@ def test_size_3(): def test_scaled_dot_product_1(): model = ScaledDotProduct() + model.set_differentiability(query=True, key=True, value=True) params = {"query": [[1.0]], "key": [[1.0]], "value": [[1.0]]} output_gradients = {"output": [[1.0]]} reference_outputs = {"output": [[1.0]]} @@ -1937,6 +1941,7 @@ def test_scaled_dot_product_1(): def test_scaled_dot_product_2(): model = ScaledDotProduct() + model.set_differentiability(query=True, key=True, value=True) params = { "query": [[1.0, 1.0], [1.0, 1.0]], "key": [[1.0, 1.0], [1.0, 1.0]], @@ -2085,7 +2090,7 @@ def test_slice_4(): def test_log_1(): - model = Log() + model = Log(input=Tensor(differentiable=True)) params = {"input": [3.0]} output_gradients = {"output": [1.0]} @@ -2108,7 +2113,7 @@ def test_log_1(): def test_log_2(): - model = Log() + model = Log(input=Tensor(differentiable=True)) params = {"input": [3.0, 2.0]} output_gradients = {"output": [1.0, 1.0]} diff --git a/tests/scripts/test_differentiablity.py b/tests/scripts/test_differentiablity.py index 431296b5..ada402c5 100644 --- a/tests/scripts/test_differentiablity.py +++ b/tests/scripts/test_differentiablity.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import mithril from mithril import JaxBackend from mithril.framework.common import Tensor @@ -96,7 +98,6 @@ def test_set_diff_data_and_param(): def test_match_tensor_with_value_data_and_param(): model1 = Multiply() model1.set_types(left=Tensor) - model1.set_differentiability(left=False) assert not model1.left.metadata.differentiable @@ -112,23 +113,44 @@ def test_match_tensor_with_value_data_and_param(): assert model.my_input.metadata.differentiable # type: ignore -def test_match_tensor_with_value_data_and_param_rev(): +def test_match_tensor_with_value_data_and_param_error(): + model1 = Multiply() + model1.set_types(left=Tensor) + model1.set_differentiability(left=False) + + assert not model1.left.metadata.differentiable + model2 = Multiply() model2.set_types(left=Tensor) model2.set_differentiability(left=True) assert model2.left.metadata.differentiable + model = Model() + model += model1(left="my_input") + with pytest.raises(ValueError) as err_info: + model += model2(left="my_input") + assert str(err_info.value) == "Differentiability mismatch!" + + +def test_match_tensor_with_value_data_and_param_error_rev(): model1 = Multiply() model1.set_types(left=Tensor) - model1.set_differentiability(left=False) + model1.set_differentiability(left=True) - assert not model1.left.metadata.differentiable + assert model1.left.metadata.differentiable + + model2 = Multiply() + model2.set_types(left=Tensor) + model2.set_differentiability(left=False) + + assert not model2.left.metadata.differentiable model = Model() model += model1(left="my_input") - model += model2(left="my_input") - assert model.my_input.metadata.differentiable # type: ignore + with pytest.raises(ValueError) as err_info: + model += model2(left="my_input") + assert str(err_info.value) == "Differentiability mismatch!" def test_diff_inference(): diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index f042846b..54f77c3c 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -169,6 +169,57 @@ def test_linear_not_expose(): ) +def test_set_cins_couts(): + model = Model() + linear_1 = Linear(dimension=42) + model |= linear_1(input="input", weight="weight", output=IOKey(name="output")) + model.set_cin("weight", linear_1.bias) + outer_model = Model() + linear_2 = Linear(dimension=42) + outer_model |= model(input="input_1", output=IOKey(name="output_1")) + outer_model |= linear_2(input="input_2", output=IOKey(name="output_2")) + outer_model.set_cout("output_2") + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert ( + model_dict_created["assigned_cins"] # type: ignore + == model_dict_recreated["assigned_cins"] # type: ignore + == [] + ) + assert ( + model_dict_created["assigned_couts"] # type: ignore + == model_dict_recreated["assigned_couts"] # type: ignore + == [("m_1", 4)] + ) + assert ( + model_dict_created["submodels"]["m_0"]["assigned_cins"] # type: ignore + == model_dict_recreated["submodels"]["m_0"]["assigned_cins"] # type: ignore + == [("self", 2), "weight"] + ) + assert ( + model_dict_created["submodels"]["m_0"]["assigned_couts"] # type: ignore + == model_dict_recreated["submodels"]["m_0"]["assigned_couts"] # type: ignore + == [] + ) + + assert_models_equal(outer_model, model_recreated) + + backend = JaxBackend(dtype=mithril.float64) + assert_evaluations_equal( + outer_model, + model_recreated, + backend, + static_keys={ + "input_1": backend.ones([4, 256]), + "input_2": backend.ones([4, 256]), + }, + ) + + def test_constant_key(): model = Model() model | Add()(left="input", right=Tensor(3), output=IOKey(name="output")) @@ -958,12 +1009,12 @@ def test_valued_scalar_in_init(): outer_model = Model() outer_model |= model() - model_dict_created = dict_conversions.model_to_dict(model) + model_dict_created = dict_conversions.model_to_dict(outer_model) model_recreated = dict_conversions.dict_to_model(model_dict_created) model_dict_recreated = dict_conversions.model_to_dict(model_recreated) assert model_dict_created == model_dict_recreated - assert_models_equal(model, model_recreated) + assert_models_equal(outer_model, model_recreated) def test_valued_scalar_in_extend(): @@ -973,12 +1024,12 @@ def test_valued_scalar_in_extend(): outer_model = Model() outer_model |= model() - model_dict_created = dict_conversions.model_to_dict(model) + model_dict_created = dict_conversions.model_to_dict(outer_model) model_recreated = dict_conversions.dict_to_model(model_dict_created) model_dict_recreated = dict_conversions.model_to_dict(model_recreated) assert model_dict_created == model_dict_recreated - assert_models_equal(model, model_recreated) + assert_models_equal(outer_model, model_recreated) def test_valued_scalar_iokey(): @@ -990,12 +1041,12 @@ def test_valued_scalar_iokey(): outer_model = Model() outer_model |= model(axis=IOKey(name="axis", value=1)) - model_dict_created = dict_conversions.model_to_dict(model) + model_dict_created = dict_conversions.model_to_dict(outer_model) model_recreated = dict_conversions.dict_to_model(model_dict_created) model_dict_recreated = dict_conversions.model_to_dict(model_recreated) assert model_dict_created == model_dict_recreated - assert_models_equal(model, model_recreated) + assert_models_equal(outer_model, model_recreated) def test_non_valued_scalar(): @@ -1005,9 +1056,187 @@ def test_non_valued_scalar(): outer_model = Model() outer_model |= model() + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + +def test_assigned_shapes(): + model = Model() + model |= Buffer()(input="buff_input", output=IOKey(name="buff_out")) + model |= Mean(axis=TBD)(input="mean_input", output=IOKey(name="mean_out")) + model.set_shapes(buff_input=[1, 2, ("V", ...)], mean_input=[("V", ...), 3, 4]) + + model_dict_created = dict_conversions.model_to_dict(model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(model, model_recreated) + + assert ( + model_dict_created.get("assigned_shapes") + == model_dict_recreated.get("assigned_shapes") + == [ + [(("m_0", 0), [1, 2, "V,..."]), (("m_1", 0), ["V,...", 3, 4])], + ] + ) + + +def test_assigned_types_1(): + model = Model() + model |= Buffer()(input="buff_input", output=IOKey(name="buff_out")) + model |= Mean(axis=TBD)(input="mean_input", output=IOKey(name="mean_out")) + model.set_types(mean_input=Tensor[int | float]) + model_dict_created = dict_conversions.model_to_dict(model) model_recreated = dict_conversions.dict_to_model(model_dict_created) model_dict_recreated = dict_conversions.model_to_dict(model_recreated) assert model_dict_created == model_dict_recreated assert_models_equal(model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [ + (("m_1", 0), "tensor"), + ] + ) + + +def test_assigned_types_2(): + model = Model() + model |= Buffer()(input="buff_input", output=IOKey(name="buff_out")) + model |= Mean(axis=TBD)(input="mean_input", output=IOKey(name="mean_out")) + model.set_types(mean_input=Tensor[int | float]) + + outer_model = Model() + outer_model |= model + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [] + ) + + assert ( + model_dict_created["submodels"]["m_0"].get("assigned_types") # type: ignore + == model_dict_recreated["submodels"]["m_0"].get("assigned_types") # type: ignore + == [ + ("mean_input", "tensor"), + ] + ) + + +def test_assigned_types_multiple_times(): + model = Model() + model |= Buffer()(input="buff_input", output=IOKey(name="buff_out")) + mean_model = Mean(axis=TBD) + model |= mean_model(input="mean_input", output=IOKey(name="mean_out")) + model.set_types(mean_input=Tensor[int | float]) + model.set_types({mean_model.input: Tensor[int | float]}) + + outer_model = Model() + outer_model |= model + + # Assert only one assignment made even thought set multiple + # times. + assert len(model.assigned_types) == 1 + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [] + ) + + assert ( + model_dict_created["submodels"]["m_0"].get("assigned_types") # type: ignore + == model_dict_recreated["submodels"]["m_0"].get("assigned_types") # type: ignore + == [ + ("mean_input", "tensor"), + ] + ) + + +def test_assigned_types_multiple_times_different_types(): + model = Model() + buff_model = Buffer() + model |= buff_model(input="buff_input", output=IOKey(name="buff_out")) + mean_model = Mean(axis=TBD) + model |= mean_model(input="mean_input", output=IOKey(name="mean_out")) + # Set types for buff_model 2 times with different types. + # Note that the last assignment will be used. + model.set_types(buff_input=Tensor[int | float] | int | float) + model.set_types({buff_model.input: int}) + + outer_model = Model() + outer_model |= model + + # Assert only one assignment made even thought set multiple + # times. + assert len(model.assigned_types) == 1 + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [] + ) + + assert ( + model_dict_created["submodels"]["m_0"].get("assigned_types") # type: ignore + == model_dict_recreated["submodels"]["m_0"].get("assigned_types") # type: ignore + == [ + ("buff_input", "int"), + ] + ) + + +def test_assigned_types_from_outermost_model(): + model = Model() + buff_model = Buffer() + model |= buff_model(input="buff_input", output=IOKey(name="buff_out")) + model |= Mean(axis=TBD)(input="mean_input", output=IOKey(name="mean_out")) + + outer_model = Model() + outer_model |= model + outer_model |= Buffer()(input="buff_input_2") + outer_model.set_types(buff_input_2=Tensor[int | float]) + outer_model.merge_connections(buff_model.input, "buff_input_2") + + model_dict_created = dict_conversions.model_to_dict(outer_model) + model_recreated = dict_conversions.dict_to_model(model_dict_created) + model_dict_recreated = dict_conversions.model_to_dict(model_recreated) + + assert model_dict_created == model_dict_recreated + assert_models_equal(outer_model, model_recreated) + + assert ( + model_dict_created.get("assigned_types") + == model_dict_recreated.get("assigned_types") + == [(("m_0", 0), "tensor")] + ) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 9db4999d..e26b2cc2 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -6973,7 +6973,7 @@ def test_extending_operator(): def test_extending_operator_model(): model1 = Buffer() - with pytest.raises(RuntimeError) as err: + with pytest.raises(AttributeError) as err: model1 += Buffer() - assert str(err.value) == "Primitive models cannot have submodels." + assert str(err.value) == "Model is frozen and can not be extended!"