diff --git a/examples/gpt/model.py b/examples/gpt/model.py index ec56d618..9888d988 100644 --- a/examples/gpt/model.py +++ b/examples/gpt/model.py @@ -101,5 +101,5 @@ def create_gpt(bias, block_size, dims, num_heads, num_layers, vocab_size): gpt = Model() gpt += transformer(input="input") gpt += Linear(vocab_size, use_bias=False, name="lm_head")(output=IOKey("output")) - gpt.input.set_differentiable(False) # type: ignore + gpt.set_differentiability({gpt.input: False}) # type: ignore return gpt diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 0000069f..a027aecd 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -827,6 +827,7 @@ def __init__( value: TensorValueType | ToBeDetermined = TBD, type: _TensorTypes = int | float | bool, shape: ShapeNode | None = None, + differentiable: bool = False, ): if shape is None: # If shape is not provided, create a new shape with a Variadic root. @@ -834,6 +835,7 @@ def __init__( self.shape: ShapeNode = shape self.type: _TensorTypes = type self.referees: set[IOHyperEdge] = set() + self.differentiable = differentiable # Initialize value as TBD and then set if any value is provided. self.value: TensorValueType | ToBeDetermined = TBD if not isinstance(value, ToBeDetermined): @@ -860,6 +862,7 @@ def set_value(self, value: TensorValueType) -> Updates: raise ValueError( f"Value is set before as {self.value}. A value can not be reset." ) + updates = Updates() # Set value. if self.value is TBD: @@ -874,6 +877,8 @@ def set_value(self, value: TensorValueType) -> Updates: updates.add(edge, update_type=UpdateType.VALUE) updates.add(edge, update_type=UpdateType.SHAPE) self.value = val + + self.differentiable = False return updates def match(self, other: Tensor[int | float | bool]) -> Updates: @@ -888,9 +893,16 @@ def match(self, other: Tensor[int | float | bool]) -> Updates: ) assert not isinstance(valued.value, ToBeDetermined) updates |= non_valued.set_value(valued.value) + + self.differentiable = False + other.differentiable = False + else: + self.differentiable |= other.differentiable + # Transfer all referees of other to self and update all # Tensors in all edges of other with self. self.referees |= other.referees + for edge in other.referees: # TODO: Update here when we have list of tensors in an edge. edge._value = self @@ -927,7 +939,6 @@ def __init__( type: set() for type in UpdateType } self._temp_shape: ShapeRepr | None = None # set random repr - self.differentiable: bool = False self.interval: list[float | int] | None = interval # Initially set type and value as not determined yet. self._type = ToBeDetermined @@ -938,6 +949,13 @@ def __init__( if value is not TBD: self.set_value(value) + @property + def differentiable(self) -> bool: + if self.is_tensor: + assert isinstance(self._value, Tensor) + return self._value.differentiable + return False + @property def is_polymorphic(self) -> bool: # Returns if the edge is of polymorphic type or not. @@ -1020,6 +1038,13 @@ def _value_compatible( return self.value is TBD or self.value == _other_value return True + def set_differentiability(self, differentiable: bool) -> None: + if self.is_tensor: + assert isinstance(self._value, Tensor) + self._value.differentiable = differentiable + elif differentiable: + raise ValueError("Non-tensor edges cannot be differentiable.") + def set_type(self, typ: type[Tensor[int | float | bool]] | ScalarType) -> Updates: updates = Updates() if self._type != typ: @@ -1048,9 +1073,7 @@ def set_type(self, typ: type[Tensor[int | float | bool]] | ScalarType) -> Update # Add self as type update, set new type and update differentiability. updates.add(self, UpdateType.TYPE) self._type = new_type - self.differentiable = (self.value is TBD) and bool( - find_intersection_type(Tensor[float], self._type) - ) + return updates def set_value( @@ -1100,8 +1123,11 @@ def set_value( self._value = value updates.add(self, UpdateType.VALUE) updates.value_updates.add(self) - # Add self to updates. - self.differentiable = self.value is TBD + + # # Add self to updates. + if self.value != TBD: + self.set_differentiability(False) + return updates def match(self, other: IOHyperEdge) -> Updates: @@ -1130,11 +1156,6 @@ def match(self, other: IOHyperEdge) -> Updates: self.constraints[type] |= other.constraints[type] other.constraints[type] = set() - # Update differentiability. - if isinstance(self._value, Tensor) and self._value.value is TBD: - is_diff = self.differentiable | other.differentiable - # TODO: Is it required to set other as well? - self.differentiable = other.differentiable = is_diff return updates def add_constraint(self, constraint: Constraint) -> None: @@ -1161,6 +1182,7 @@ def __init__( | ScalarType | None = None, expose: bool | None = None, + differentiable: bool = False, interval: list[float | int] | None = None, connections: set[ConnectionData | str] | None = None, ) -> None: @@ -1176,6 +1198,9 @@ def __init__( # Convert to generic Tensor type if Tensor type is provided. type = Tensor[int | float | bool] + if differentiable: + type = Tensor[float] + self.name = name self.expose = expose if connections is None: @@ -1193,6 +1218,9 @@ def __init__( f"Got {shape}." ) + if differentiable and value is not TBD: + raise ValueError("Scalar values can not be set as differentiable.") + if value is not TBD and type is not None: value_type = find_type(value) if find_intersection_type(value_type, type) is None: @@ -1204,6 +1232,7 @@ def __init__( self.value_shape = shape self.type = type self.interval = interval + self.differentiable = differentiable @dataclass @@ -1226,14 +1255,16 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: return id(self) == id(other) - def set_differentiable(self, differentiable: bool = True) -> None: + def set_differentiability(self, differentiable: bool = True) -> Updates: + updates = Updates() # TODO: Move this method to Model class as set_shapes, set_types etc. if self.metadata.is_tensor: - self.metadata.differentiable = differentiable + self.metadata.set_differentiability(differentiable) elif differentiable: - if self.metadata.edge_type is not ToBeDetermined: - raise ValueError("Scalar data can not be set as differentiable.") - self.metadata.differentiable = differentiable + updates |= self.metadata.set_type(Tensor[float]) + self.metadata.set_differentiability(differentiable) + + return updates ShapesType = ( @@ -1377,10 +1408,9 @@ def get_type(self, key: ConnectionData) -> KeyType: raise ValueError("No matching key type found!") def get_non_diff_keys(self) -> set[str]: - return {key for key, conn in self.all.items() if conn.metadata.is_non_diff} - - def is_key_non_diff(self, key: str) -> bool: - return self.get_data(key).is_non_diff + return { + key for key, conn in self.all.items() if not conn.metadata.differentiable + } def get_connection(self, key: str) -> ConnectionData | None: internals = self._connection_dict[KeyType.INTERNAL] diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 23ded642..e235a26e 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -231,6 +231,7 @@ def _prepare_keys( type=connection.type, shape=connection.value_shape, value=connection.value, + differentiable=connection.differentiable, ) return _connection @@ -250,6 +251,7 @@ def _add_connection( d_map = self.dependency_map.local_output_dependency_map expose = given_connection.expose outer_key = given_connection.name + con_obj = None set_value: ( ToBeDetermined @@ -347,6 +349,9 @@ def _add_connection( if not isinstance(set_value, NullConnection): updates |= con_obj.metadata.set_value(set_value) + if given_connection.differentiable: + updates |= con_obj.set_differentiability(True) + # Check multi-write error for con_obj. self._check_multi_write(is_input, local_connection, con_obj) @@ -1072,6 +1077,27 @@ def _create_connection( self.conns.add(con) return con + def _set_differentiability( + self, config: dict[str | ConnectionData, bool], **kwargs: bool + ) -> None: + updates = Updates() + + for key, value in chain(config.items(), kwargs.items()): + if isinstance(key, str): + if key not in self.conns.all: + raise KeyError(f"Connection {key} is not found in the model.") + + conn_data = self.conns.all[key] + updates |= conn_data.set_differentiability(value) + elif isinstance(key, ConnectionData): + if key not in self.conns.all.values(): + raise KeyError(f"Connection {key} is not found in the model.") + + updates |= key.set_differentiability(value) + + model = self._get_outermost_parent() + model.constraint_solver(updates) + def _set_shapes( self, shapes: ShapesType, diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 2e9187af..e7cc76a2 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -390,9 +390,6 @@ def key(self) -> str: def metadata(self) -> IOHyperEdge: return self.data.metadata - def set_differentiable(self, differentiable: bool = True) -> None: - self.data.set_differentiable(differentiable) - def __hash__(self) -> int: return hash(id(self)) @@ -412,6 +409,7 @@ def __init__( | ScalarType | None = None, expose: bool | None = None, + differantiable: bool = False, interval: list[float | int] | None = None, connections: set[Connection | str] | None = None, ) -> None: @@ -428,6 +426,7 @@ def __init__( expose=expose, interval=interval, connections=_connections, + differentiable=differantiable, ) @@ -591,6 +590,7 @@ def _prepare_keys( type=connection.type, shape=connection.value_shape, value=connection.value, + differentiable=connection.differentiable, ) case _: _connection = connection # type: ignore @@ -725,6 +725,18 @@ def __or__(self, info: ExtendInfo | Model) -> Self: | Mapping[Connection, ShapeTemplateType] ) + def set_differentiability( + self, config: dict[str | Connection, bool] | None = None, **kwargs: bool + ) -> None: + if config is None: + config = {} + + _config: dict[str | ConnectionData, bool] = { + key.data if isinstance(key, Connection) else key: value + for key, value in config.items() + } + self._set_differentiability(_config, **kwargs) + def set_shapes( self, config: ShapeType | None = None, **kwargs: ShapeTemplateType ) -> None: diff --git a/mithril/framework/logical/operator.py b/mithril/framework/logical/operator.py index 3e679b1e..adfc4566 100644 --- a/mithril/framework/logical/operator.py +++ b/mithril/framework/logical/operator.py @@ -64,6 +64,7 @@ def __init__( tensor = Tensor( type=get_args(value.type)[0], shape=shapes[key].node, + differentiable=value.differentiable, ) edge = IOHyperEdge(value=tensor, interval=value.interval) data_set.add(edge) @@ -144,3 +145,10 @@ def extend( **kwargs: ConnectionDataType, ) -> None: raise NotImplementedError("Operators cannot be extended!") + + def infer_differentiability(self, *inputs: bool) -> bool: + # Function to infer differentiability of the operator + # based on the differentiability of its inputs + + # If any of the inputs are differentiable, the output is differentiable + return any(inputs) diff --git a/mithril/framework/logical/operators.py b/mithril/framework/logical/operators.py index ac59d7d7..34511843 100644 --- a/mithril/framework/logical/operators.py +++ b/mithril/framework/logical/operators.py @@ -431,7 +431,9 @@ def __init__( super().__init__( formula_key="floor_divide", name=name, - output=BaseKey(type=Tensor[int | float] | int | float), + output=BaseKey( + type=Tensor[int | float] | int | float, differentiable=False + ), numerator=BaseKey(value=numerator), denominator=BaseKey(value=denominator), ) @@ -464,6 +466,9 @@ def __init__( dependencies={bcast_constraint}, ) + def infer_differentiability(self, *inputs: bool) -> bool: + return False + class MatrixMultiplyOp(Operator): _model_name: str = "MatrixMultiply" diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 911d1d2b..1d5588f9 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -31,7 +31,17 @@ def __init__( name: str | None = None, ) -> None: super().__init__(name=name, enforce_jit=model._jittable) - self._extend(model, {k: IOKey(k, expose=True) for k in model.external_keys}) + self._extend( + model, + { + k: IOKey( + k, + expose=True, + differantiable=model.conns.all[k].metadata.differentiable, + ) + for k in model.external_keys + }, + ) @property def submodel(self) -> Operator: diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index f93465e2..b299af29 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -187,7 +187,7 @@ def __init__( if global_key in self._non_differentiable_keys: # TODO: Create an API for setting differentiability of a tensor. - physical_data.differentiable = False + physical_data.set_differentiability(False) elif global_key in self._trainable_tensor_inputs: # if physical_data.edge_type not in (Tensor, ToBeDetermined): if not ( @@ -204,7 +204,7 @@ def __init__( raise ValueError( f"Valued data can not be trainable: {global_key}" ) - physical_data.differentiable = True + physical_data.set_differentiability(True) model_data[key] = physical_data self.flat_graph.data_memo[id(logical_data)] = physical_data @@ -224,7 +224,7 @@ def __init__( output = Operator.output_key _data_dict: dict[str, IOHyperEdge] = {} - self._infer_differentiability(model_data) + self._infer_differentiability(p_model, model_data) for inner_key in p_model.external_keys: outer_key = mappings[inner_key] if outer_key not in self.data: @@ -414,21 +414,22 @@ def output_keys(self) -> list[str]: def input_keys(self) -> set[str]: return self._input_keys - def _infer_differentiability(self, model_data: dict[str, IOHyperEdge]) -> None: + def _infer_differentiability( + self, p_model: Operator, model_data: dict[str, IOHyperEdge] + ) -> None: # Infer output differentiability only for the models # that have a Tensor type output. output_key = Operator.output_key output_edge = model_data[output_key] + input_diffs = [ + value.differentiable + for key, value in model_data.items() + if key != output_key + ] + if output_edge.is_tensor: - # If any of the inputs are differentiable, then - # the output is also differentiable. - for key, value in model_data.items(): - if key != output_key and not value.is_non_diff: - output_edge.differentiable = True - return - # If all inputs are non-differentiable, then the output is also - # non-differentiable. - output_edge.differentiable = False + diff = p_model.infer_differentiability(*input_diffs) + output_edge.set_differentiability(diff) def randomize_params( self, diff --git a/mithril/models/models.py b/mithril/models/models.py index b621b48c..f98a24ca 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -196,7 +196,7 @@ def __init__( dilation=IOKey(name="dilation", value=dilation), output=IOKey(name="output"), ) - self.input.set_differentiable(False) + self.set_cin("input", safe=False) self._freeze() @@ -298,7 +298,6 @@ def __init__( dilation=dt_converter.output, output=IOKey(name="output"), ) - self.input.set_differentiable(False) self.set_cin("input", safe=False) self._freeze() @@ -385,16 +384,15 @@ def __init__( conv_connections: dict[str, ConnectionType] = { "output": IOKey(name="output"), "input": IOKey("input", value=input), - "weight": IOKey("weight", value=weight), + "weight": IOKey("weight", value=weight, differantiable=True), "stride": IOKey(name="stride", value=stride), "padding": p_converter.output, "dilation": IOKey(name="dilation", value=dilation), } if use_bias: - conv_connections["bias"] = "bias" + conv_connections["bias"] = IOKey("bias", differantiable=True) self |= PrimitiveConvolution1D(use_bias=use_bias)(**conv_connections) - self.input.set_differentiable(False) self.set_cin("input", safe=False) self._freeze() @@ -480,16 +478,15 @@ def __init__( conv_connections: dict[str, ConnectionType] = { "output": IOKey(name="output"), "input": IOKey("input", value=input), - "weight": IOKey("weight", value=weight), + "weight": IOKey("weight", value=weight, differantiable=True), "stride": st_converter.output, "padding": pt_converter.output, "dilation": dt_converter.output, } if use_bias: - conv_connections["bias"] = "bias" + conv_connections["bias"] = IOKey("bias", differantiable=True) self |= PrimitiveConvolution2D(use_bias=use_bias)(**conv_connections) - self.input.set_differentiable(False) self.set_cin("input", safe=False) self._freeze() @@ -541,9 +538,15 @@ def __init__( output = IOKey(name="output") input_key = IOKey(name="input", value=input) - weight_key = IOKey(name="weight", value=weight).transpose() + weight_key = IOKey(name="weight", value=weight, differantiable=True).transpose() + if use_bias: - bias_key = IOKey(name="bias", value=bias, type=Tensor[int | float | bool]) + bias_key = IOKey( + name="bias", + value=bias, + type=Tensor[int | float | bool], + differantiable=True, + ) self |= mult(left=input_key, right=weight_key) self |= Add()(left=mult.output, right=bias_key, output=output) shapes["bias"] = [dim] @@ -551,7 +554,6 @@ def __init__( self |= mult(left=input_key, right=weight_key, output=output) self._set_shapes(shapes) - self.input.set_differentiable(False) self.set_cin("input", safe=False) self._freeze() @@ -600,7 +602,6 @@ def __init__( right=IOKey(name="bias", value=bias), output=IOKey(name="output"), ) - self.input.set_differentiable(False) self.set_cin("input", safe=False) self._freeze() @@ -701,7 +702,6 @@ def __init__( self += Divide()(numerator=numerator.output, denominator=denominator.output) self._set_shapes({"input": ["B", "C", "d"]}) - self.input.set_differentiable(False) shapes: dict[str, ShapeTemplateType] = { "left": ["B", "C", "d"], @@ -713,7 +713,9 @@ def __init__( mult.set_types( left=Tensor[int | float | bool], right=Tensor[int | float | bool] ) - self += mult(left=self.cout, right=IOKey("weight", value=weight)) + self += mult( + left=self.cout, right=IOKey("weight", value=weight, differantiable=True) + ) mult._set_shapes(shapes) if use_bias: @@ -721,7 +723,9 @@ def __init__( add.set_types( left=Tensor[int | float | bool], right=Tensor[int | float | bool] ) - self += add(left=self.cout, right=IOKey("bias", value=bias)) + self += add( + left=self.cout, right=IOKey("bias", value=bias, differantiable=True) + ) add._set_shapes(shapes) # TODO: Remove below Buffer after required naming-related changes are done. self += Buffer()(input=self.cout, output=IOKey(name="output")) @@ -783,7 +787,6 @@ def __init__( self |= Reshape()(input=_input_key, shape=input_shape) self._set_shapes({"input": ["B", "C", "H", "W"]}) - self.input.set_differentiable(False) shapes: dict[str, ShapeTemplateType] = { "left": ["B", "C", "H", "W"], @@ -791,13 +794,17 @@ def __init__( } if use_scale: - weight_key = IOKey(name="weight", type=Tensor[float], value=weight) + weight_key = IOKey( + name="weight", type=Tensor[float], value=weight, differantiable=True + ) mult = Multiply() self |= mult(left=self.cout, right=weight_key) mult._set_shapes(shapes) if use_bias: - bias_key = IOKey(name="bias", type=Tensor[float], value=bias) + bias_key = IOKey( + name="bias", type=Tensor[float], value=bias, differantiable=True + ) add = Add() self |= add(left=self.cout, right=bias_key) add._set_shapes(shapes) @@ -1102,11 +1109,12 @@ def __init__( linear_model = Linear() # Get kernel inputs from given model. - kernel_input_args = { - key: IOKey(key, value=kwargs.get(key, TBD)) - for key in kernel.input_keys - if not kernel.conns.is_key_non_diff(key) - } + kernel_input_args = {} + for key in kernel.input_keys: + conn = kernel.conns.get_connection(key) + if conn and conn.metadata.is_tensor and not key.startswith("$"): + kernel_input_args[key] = IOKey(key, value=kwargs.get(key, TBD)) + (kernel_output_name,) = kernel.conns.output_keys # NOTE: Assumes single output! kernel_output_args = {kernel_output_name: IOKey(name="kernel")} @@ -1177,7 +1185,6 @@ def __init__( self += decision_model( input=linear_model.output, output=IOKey(name="decision_output") ) - self.input.set_differentiable(False) self.set_cout(linear_model.output) self._freeze() @@ -1229,7 +1236,6 @@ def __init__( input=linear_model.output, output=IOKey(name="probs_output") ) - self.input.set_differentiable(False) self.set_cout(linear_model.output) self._freeze() @@ -1401,18 +1407,24 @@ def __init__( self |= slice_2(stop=scalar_item.output) self += tensor_item_2(input="prev_hidden", index=slice_2.output) self += mult_model_1( - input=tensor_item_2.output, weight=IOKey("w_hh", value=w_hh) + input=tensor_item_2.output, + weight=IOKey("w_hh", value=w_hh, differantiable=True), + ) + self += mult_model_2( + input="input", weight=IOKey("w_ih", value=w_ih, differantiable=True) ) - self += mult_model_2(input="input", weight=IOKey("w_ih", value=w_ih)) self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2( - left=sum_model_1.output, right=IOKey("bias_h", value=bias_h) + left=sum_model_1.output, + right=IOKey("bias_h", value=bias_h, differantiable=True), ) self += Tanh()(input=sum_model_2.output, output=IOKey(name="hidden")) - self += mult_model_3(input="hidden", weight=IOKey("w_ho", value=w_ho)) + self += mult_model_3( + input="hidden", weight=IOKey("w_ho", value=w_ho, differantiable=True) + ) self += Add()( left=mult_model_3.output, - right=IOKey("bias_o", value=bias_o), + right=IOKey("bias_o", value=bias_o, differantiable=True), output=IOKey(name="output"), ) shapes: dict[str, ShapeTemplateType] = { @@ -1424,8 +1436,7 @@ def __init__( "bias_h": ["d_hid"], "bias_o": ["d_out"], } - self.input.set_differentiable(False) - self.prev_hidden.set_differentiable(False) + self._set_shapes(shapes) self.set_cin("input", safe=False) self.set_cout("output") @@ -1600,9 +1611,7 @@ def __init__( "hidden": ["N", 1, "d_hid"], "cell": ["N", 1, "d_hid"], } - self.input.set_differentiable(False) - self.prev_hidden.set_differentiable(False) - self.prev_cell.set_differentiable(False) + self._set_shapes(shapes) self.set_cin("input", safe=False) self.set_cout("output") @@ -1757,7 +1766,6 @@ def __init__( "bias_o": ["d_hid"], } - self.input.set_differentiable(False) self._set_shapes(shapes) self._freeze() @@ -2223,7 +2231,7 @@ def __init__( bias=IOKey("bias", value=bias), output=IOKey(name="output"), ) - self.input.set_differentiable(False) + self._freeze() def __call__( # type: ignore[override] @@ -2316,7 +2324,6 @@ def __init__( output=IOKey(name="output"), ) - self.distances.set_differentiable(False) self._set_shapes({"distances": ["N", "N"], "pred_distances": ["N", "N"]}) self._freeze() @@ -2405,7 +2412,6 @@ def __init__( input=kl_divergence_model.output, output=IOKey(name="output") ) - self.distances.set_differentiable(False) self._set_shapes({"distances": ["N", "N"], "pred_distances": ["N", "N"]}) self.set_cin("distances", safe=False) self.set_cout("output") @@ -2513,7 +2519,6 @@ def __init__( input=self.coords, output=IOKey(name="predicted_coords") ) - self.input.set_differentiable(False) self._freeze() # self._set_shapes(trace=False, # input = ["N", "M"], # NOTE: Here "M" denotes input dim or @@ -2757,7 +2762,7 @@ def __init__( "prediction": ["N", 1], "confidence": ["N", 1], } - self.label.set_differentiable(False) + self._set_shapes(shapes) self._freeze() @@ -2847,7 +2852,7 @@ def __init__( "alpha": ["N", 1], "output": [1], } - self.labels.set_differentiable(False) + self._set_shapes(shapes) self._freeze() @@ -2927,7 +2932,6 @@ def __init__( self += Buffer()(input=label_key, output=IOKey("label_formatted")) self += Buffer()(input=result, output=IOKey("output")) - self.label.set_differentiable(False) self.set_cin(self.pred) self._freeze() @@ -3149,7 +3153,6 @@ def __init__( self += Buffer()(input=precision, output=IOKey(name="output")) - self.label.set_differentiable(False) self.set_cin(self.pred) self._freeze() @@ -3309,7 +3312,6 @@ def __init__( self += Buffer()(input=recall, output=IOKey(name="output")) - self.label.set_differentiable(False) self.set_cin(self.pred) self._freeze() @@ -3474,7 +3476,6 @@ def __init__( self += Buffer()(input=precision, output=IOKey(name="output")) - self.label.set_differentiable(False) self.set_cin(self.pred) self._freeze() @@ -3536,7 +3537,6 @@ def __init__( self += Buffer()(auc_score, IOKey("output")) - self.label.set_differentiable(False) self.set_cin(self.pred) self._freeze() @@ -3573,7 +3573,6 @@ def __init__( ) self._set_shapes({"input": [("Var", ...)], "output": [("Var", ...)]}) - self.input.set_differentiable(False) self._freeze() def __call__( # type: ignore[override] diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index ef4e37a0..5b5efb0f 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -1121,7 +1121,6 @@ def __init__( self._add_constraint( fn=general_tensor_type_constraint, keys=[Operator.output_key, "input"] ) - self.indices.set_differentiable(False) def __call__( # type: ignore[override] self, @@ -2230,7 +2229,12 @@ def __init__( name=name, output=BaseKey(shape=[("N1", ...), "d1", out_dim], type=Tensor), input=BaseKey(shape=[("N1", ...), "d1"], type=Tensor[int], value=input), - weight=BaseKey(shape=[num_embeddings, out_dim], type=Tensor, value=weight), + weight=BaseKey( + shape=[num_embeddings, out_dim], + type=Tensor, + value=weight, + differentiable=True, + ), ) self._add_constraint( diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index ebf8e9dc..062b998e 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -174,11 +174,16 @@ def add_loss( **kwargs: Any, ) -> None: # If provided key namings does not match with Loss model + if { key for key, value in loss_model(**kwargs).connections.items() if value is NOT_GIVEN and key in loss_model.input_keys - } - loss_model.conns.get_non_diff_keys(): + } - { + conn.key + for conn in loss_model.conns.input_connections + if conn.metadata.is_scalar + }: # if set(kwargs.keys()) != keys: raise KeyError("The provided keys do not match the model's loss.") @@ -284,6 +289,7 @@ def add_loss( if (loss_con := self.conns.get_con_by_metadata(prev_out_key.metadata)) is None: raise KeyError("Given key does not belong to the Model!") + self.loss_keys[loss_key] = loss_con.key # TODO: maybe only add reduce_inputs if it is not empty @@ -298,7 +304,15 @@ def add_regularization( key_name: str | None = None, **kwargs: Any, ) -> None: - keys = set(model.input_keys) - model.conns.get_non_diff_keys() + valued_input_keys = { + key + for key, conn in zip( + model.conns.input_keys, model.conns.input_connections, strict=False + ) + if conn.metadata.is_valued + } + + keys = set(model.input_keys) - valued_input_keys if set(kwargs.keys()) != keys: raise KeyError( "The provided keys do not match the regularization model keys!" @@ -342,6 +356,7 @@ def _add_regularization( reg_str = reg_key.data.key case None: reg_str = model.cin.key + if any([isinstance(value, re.Pattern) for value in kwargs.values()]): if len(kwargs) > 1: raise Exception( @@ -398,6 +413,7 @@ def _add_regularization( keywords[key] = value.data else: keywords[key] = value + self.extend(model, **keywords) if isinstance(outer_key := kwargs[reg_str], ConnectionData): outer_key = outer_key.key @@ -463,10 +479,9 @@ def _add_loss_combiner(self) -> None: concat_model = Concat(n=num_of_loss_keys, axis=None) concat_kwargs: dict[Any, Any] = {} idx = 0 - for key in concat_model.input_keys: - # if not concat_model.connections[key].metadata.value.is_non_diff: - if not concat_model.conns.is_key_non_diff(key): - concat_kwargs[key] = self.conns.all[ + for conn in concat_model.conns.input_connections: + if conn.metadata.is_tensor: + concat_kwargs[conn.key] = self.conns.all[ list(self.loss_keys.values())[idx] ] idx += 1 @@ -721,9 +736,9 @@ def _add_geo_mean(self) -> None: concat_model = Concat(n=n_final_outputs, axis=None) concat_kwargs: dict[str, Tensor[int] | ConnectionData] = {} idx = 0 - for key in concat_model.input_keys: - if not concat_model.conns.is_key_non_diff(key): - concat_kwargs[key] = final_outputs[idx] + for conn in concat_model.conns.input_connections: + if conn.metadata.is_tensor: + concat_kwargs[conn.key] = final_outputs[idx] idx += 1 self.extend(concat_model, **concat_kwargs) @@ -765,9 +780,9 @@ def _add_reduce_sizes( concat_model = Concat(n=num_of_sizes, axis=None) concat_kwargs: dict[str, int | ConnectionData] = {} idx = 0 - for key in concat_model.input_keys: - if not concat_model.conns.is_key_non_diff(key): - concat_kwargs[key] = sizes[idx] + for conn in concat_model.conns.input_connections: + if conn.metadata.is_tensor: + concat_kwargs[conn.key] = sizes[idx] idx += 1 self.extend(concat_model, **concat_kwargs) self.extend(prod := Prod(), input=concat_model.output.data) diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index b0c3aa3e..2e295aec 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -215,8 +215,6 @@ def dict_to_model( # TODO: Get rid of using eval method. Find more secure # way to convert strings into types and generic types. set_types[key] = eval(typ) - if set_types: - model.set_types(set_types) # type: ignore unnamed_keys: list[str] = params.get("unnamed_keys", []) differentiability_info: dict[str, bool] = params.get("differentiability_info", {}) @@ -249,6 +247,9 @@ def dict_to_model( model |= m(**mappings) + if set_types: + model.set_types(set_types) + if "model" in canonical_keys: if canonical_keys["model"][0] is not None: model.set_cin(*canonical_keys["model"][0]) @@ -258,7 +259,7 @@ def dict_to_model( for key, value in differentiability_info.items(): con = model.conns.get_connection(key) assert con is not None - con.set_differentiable(value) + con.set_differentiability(value) if len(assigned_constraints) > 0: for constr_info in assigned_constraints: @@ -284,14 +285,14 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: model_name = model.__class__.__name__ args = handle_model_to_dict_args(model_name, model.factory_args) assigned_shapes: dict[str, ShapeTemplateType] = {} - differentiablility_info: dict[str, bool] = {} + differentiability_info: dict[str, bool] = {} assigned_constraints: list[AssignedConstraintType] = [] types: dict[str, str] = {} for key, con in model.conns.all.items(): edge = con.metadata if edge.is_tensor and not con.is_key_autogenerated: - differentiablility_info[key] = edge.differentiable + differentiability_info[key] = edge.differentiable for shape in model.assigned_shapes: assigned_shapes |= shape_to_dict(shape) @@ -314,7 +315,7 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: "name": model_name, "args": args, "assigned_shapes": assigned_shapes, - "differentiability_info": differentiablility_info, + "differentiability_info": differentiability_info, "assigned_constraints": assigned_constraints, "types": types, } @@ -351,7 +352,7 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: "name": model_name, "args": args, "assigned_shapes": assigned_shapes, - "differentiability_info": differentiablility_info, + "differentiability_info": differentiability_info, "assigned_constraints": assigned_constraints, "types": types, "connections": connection_dict, @@ -377,7 +378,9 @@ def connection_to_dict( for key, connection in connections.items(): key_value: ConnectionDict | None | str | AllValueType = None related_conn = submodel_connections.get(connection.metadata, []) - is_valued = connection.metadata.is_non_diff and connection.metadata.value != TBD + is_valued = ( + not connection.metadata.differentiable and connection.metadata.value != TBD + ) # Connection is defined and belong to another model if related_conn and model_id not in related_conn: key_value = {} diff --git a/tests/json_files/randomized_model_tests_all_backends.json b/tests/json_files/randomized_model_tests_all_backends.json index 3e524ac2..b845b7a9 100644 --- a/tests/json_files/randomized_model_tests_all_backends.json +++ b/tests/json_files/randomized_model_tests_all_backends.json @@ -1106,6 +1106,9 @@ "model": { "name": "LeakyRelu" }, + "static_input_info": { + "slope": 0.01 + }, "input_info": { "input" : { "shapes": [[1,10], [1, 10], [1, 10], [1, 10], [1, 10]] diff --git a/tests/json_files/static_keys_directed_test.json b/tests/json_files/static_keys_directed_test.json index a41786df..cad1ce40 100644 --- a/tests/json_files/static_keys_directed_test.json +++ b/tests/json_files/static_keys_directed_test.json @@ -63,7 +63,7 @@ }, "results": { - "static_keys": ["output2"] + "static_keys": ["input2", "output2"] } }, @@ -175,7 +175,7 @@ }, "results": { - "static_keys": ["output1", "output2", "output4"] + "static_keys": ["input2", "output1", "output2", "output4"] } }, "test_composite_3": { @@ -287,7 +287,7 @@ }, "results": { - "static_keys": ["output3"] + "static_keys": ["input1", "output3"] } }, "test_composite_4": { @@ -498,7 +498,7 @@ }, "results": { - "static_keys": [] + "static_keys": ["input1", "input2"] } }, "test_composite_6": { @@ -577,7 +577,7 @@ }, "results": { - "static_keys": ["output1"] + "static_keys": ["input2", "input3", "output1"] } }, "test_composite_7": { @@ -662,7 +662,7 @@ }, "results": { - "static_keys": ["input2", "output1", "output3"] + "static_keys": ["input2", "input3", "output1", "output3"] } }, "test_composite_8": { @@ -743,7 +743,7 @@ }, "results": { - "static_keys": ["output2"] + "static_keys": ["input1", "output2"] } } } \ No newline at end of file diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 3e961693..f422f58c 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -74,6 +74,7 @@ def evaluate_case( backend=backend, constant_keys=static_keys, discard_keys=discard_keys, + trainable_keys=reference_gradients, jit=False, safe_shapes=True, ) diff --git a/tests/scripts/summary_txts/test_physical_model_summary_5 b/tests/scripts/summary_txts/test_physical_model_summary_5 index e85f3958..5130928c 100644 --- a/tests/scripts/summary_txts/test_physical_model_summary_5 +++ b/tests/scripts/summary_txts/test_physical_model_summary_5 @@ -1,4 +1,3 @@ - Model Info ======================================================= Backend type : jax @@ -11,11 +10,11 @@ Output keys : output ------------------------------------------------------- Constant inputs : None ------------------------------------------------------- -Static keys : None +Static keys : denominator, exponent, left, right ------------------------------------------------------- -Trainable keys : denominator, exponent, left, right +Trainable keys : None ------------------------------------------------------- -Total Parameters : >3 +Total Parameters : 0 ------------------------------------------------------- @@ -25,18 +24,18 @@ Model Name | Model Keys | ------------------------------------------------------ | Keys : Shapes : Connections : Parameters ================================================================================= -Add | Inputs : left : [u1, u2] : 'left' : Unknown - | right : [ 1] : 'right' : 1 +Add | Inputs : left : [u1, u2] : 'left' : 0 + | right : [ 1] : 'right' : 0 | ---------------------------------------------------------------- | Outputs : output : [u1, u2] : Divide.numerator : 0 --------------------------------------------------------------------------------- Divide | Inputs : numerator : [u1, u2] : Add.output : 0 - | denominator : [ 1] : 'denominator' : 1 + | denominator : [ 1] : 'denominator' : 0 | ---------------------------------------------------------------- | Outputs : output : [u1, u2] : Power.base : 0 --------------------------------------------------------------------------------- Power | Inputs : base : [u1, u2] : Divide.output : 0 - | exponent : [ 1] : 'exponent' : 1 + | exponent : [ 1] : 'exponent' : 0 | ---------------------------------------------------------------- | Outputs : output : [u1, u2] : 'output' : 0 --------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_model_summary_6 b/tests/scripts/summary_txts/test_physical_model_summary_6 index 9816ee23..38fd8965 100644 --- a/tests/scripts/summary_txts/test_physical_model_summary_6 +++ b/tests/scripts/summary_txts/test_physical_model_summary_6 @@ -1,21 +1,21 @@ - Model Info -============================================== -Backend type : jax ----------------------------------------------- -Backend precision : 32 ----------------------------------------------- -Backend device : TFRT_CPU_0 ----------------------------------------------- -Output keys : output ----------------------------------------------- -Constant inputs : None ----------------------------------------------- -Static keys : None ----------------------------------------------- -Trainable keys : b, input1, input2, weight ----------------------------------------------- -Total Parameters : >0 ----------------------------------------------- + Model Info +=================================== +Backend type : jax +----------------------------------- +Backend precision : 32 +----------------------------------- +Backend device : TFRT_CPU_0 +----------------------------------- +Output keys : output +----------------------------------- +Constant inputs : None +----------------------------------- +Static keys : input1, input2 +----------------------------------- +Trainable keys : b, weight +----------------------------------- +Total Parameters : >0 +----------------------------------- PhysicalModel @@ -24,8 +24,8 @@ Model Name | Model Keys | ------------------------------------------------------ | Keys : Shapes : Connections : Parameters ===================================================================================== -Add_0 | Inputs : left : [u1, u2] : 'input1' : Unknown - | right : [u1, u2] : 'input2' : Unknown +Add_0 | Inputs : left : [u1, u2] : 'input1' : 0 + | right : [u1, u2] : 'input2' : 0 | ---------------------------------------------------------------- | Outputs : output : [u1, u2] : Relu.input : 0 ------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_model_summary_7 b/tests/scripts/summary_txts/test_physical_model_summary_7 index 9816ee23..38fd8965 100644 --- a/tests/scripts/summary_txts/test_physical_model_summary_7 +++ b/tests/scripts/summary_txts/test_physical_model_summary_7 @@ -1,21 +1,21 @@ - Model Info -============================================== -Backend type : jax ----------------------------------------------- -Backend precision : 32 ----------------------------------------------- -Backend device : TFRT_CPU_0 ----------------------------------------------- -Output keys : output ----------------------------------------------- -Constant inputs : None ----------------------------------------------- -Static keys : None ----------------------------------------------- -Trainable keys : b, input1, input2, weight ----------------------------------------------- -Total Parameters : >0 ----------------------------------------------- + Model Info +=================================== +Backend type : jax +----------------------------------- +Backend precision : 32 +----------------------------------- +Backend device : TFRT_CPU_0 +----------------------------------- +Output keys : output +----------------------------------- +Constant inputs : None +----------------------------------- +Static keys : input1, input2 +----------------------------------- +Trainable keys : b, weight +----------------------------------- +Total Parameters : >0 +----------------------------------- PhysicalModel @@ -24,8 +24,8 @@ Model Name | Model Keys | ------------------------------------------------------ | Keys : Shapes : Connections : Parameters ===================================================================================== -Add_0 | Inputs : left : [u1, u2] : 'input1' : Unknown - | right : [u1, u2] : 'input2' : Unknown +Add_0 | Inputs : left : [u1, u2] : 'input1' : 0 + | right : [u1, u2] : 'input2' : 0 | ---------------------------------------------------------------- | Outputs : output : [u1, u2] : Relu.input : 0 ------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_model_summary_8 b/tests/scripts/summary_txts/test_physical_model_summary_8 index 34ff7e72..1e044166 100644 --- a/tests/scripts/summary_txts/test_physical_model_summary_8 +++ b/tests/scripts/summary_txts/test_physical_model_summary_8 @@ -1,4 +1,3 @@ - Model Info =============================================== Backend type : jax @@ -11,11 +10,11 @@ Output keys : output ----------------------------------------------- Constant inputs : None ----------------------------------------------- -Static keys : None +Static keys : input1, input2_0, input2_1 ----------------------------------------------- -Trainable keys : input1, input2_0, input2_1 +Trainable keys : None ----------------------------------------------- -Total Parameters : >0 +Total Parameters : 0 ----------------------------------------------- @@ -25,8 +24,8 @@ Model Name | Model Keys | ------------------------------------ | Keys : Connections : Parameters =============================================================== -Add_0 | Inputs : left : 'input1' : Unknown - | right : 'input2_0' : Unknown +Add_0 | Inputs : left : 'input1' : 0 + | right : 'input2_0' : 0 | ---------------------------------------------- | Outputs : output : Relu.input : 0 --------------------------------------------------------------- @@ -39,7 +38,7 @@ Sigmoid | Inputs : input : Relu.output : 0 | Outputs : output : Add_1.left : 0 --------------------------------------------------------------- Add_1 | Inputs : left : Sigmoid.output : 0 - | right : 'input2_1' : Unknown + | right : 'input2_1' : 0 | ---------------------------------------------- | Outputs : output : 'output' : 0 --------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_model_summary_9 b/tests/scripts/summary_txts/test_physical_model_summary_9 index c533a205..7dc051a1 100644 --- a/tests/scripts/summary_txts/test_physical_model_summary_9 +++ b/tests/scripts/summary_txts/test_physical_model_summary_9 @@ -1,4 +1,3 @@ - Model Info =============================== Backend type : jax @@ -11,11 +10,11 @@ Output keys : output ------------------------------- Constant inputs : None ------------------------------- -Static keys : None +Static keys : input1 ------------------------------- -Trainable keys : input1 +Trainable keys : None ------------------------------- -Total Parameters : >0 +Total Parameters : 0 ------------------------------- @@ -25,7 +24,7 @@ Model Name | Model Keys | ----------------------------------------------- | Keys : Shapes : Connections : Parameters ========================================================================== -Relu_0 | Inputs : input : [u1, u2] : 'input1' : Unknown +Relu_0 | Inputs : input : [u1, u2] : 'input1' : 0 | --------------------------------------------------------- | Outputs : output : [u1, u2] : Relu_1.input : 0 -------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_10 b/tests/scripts/summary_txts/test_physical_summary_10 index 6cf41693..9cd4694b 100644 --- a/tests/scripts/summary_txts/test_physical_summary_10 +++ b/tests/scripts/summary_txts/test_physical_summary_10 @@ -10,11 +10,11 @@ Output keys : output1 ------------------------------- Constant inputs : None ------------------------------- -Static keys : None +Static keys : input ------------------------------- -Trainable keys : input +Trainable keys : None ------------------------------- -Total Parameters : >0 +Total Parameters : 0 ------------------------------- @@ -24,7 +24,7 @@ Model Name | Model Keys | -------------------------------------------------------------------- | Keys : Shapes : Types : Connections : Parameters =============================================================================================== -Sigmoid | Inputs : input : [(V1, ...)] : bool | float | int : 'input' : Unknown +Sigmoid | Inputs : input : [(V1, ...)] : bool | float | int : 'input' : 0 | ------------------------------------------------------------------------------ | Outputs : output : [(V1, ...)] : float : 'output1' : 0 ----------------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_11 b/tests/scripts/summary_txts/test_physical_summary_11 index 7986b822..9b7e03e7 100644 --- a/tests/scripts/summary_txts/test_physical_summary_11 +++ b/tests/scripts/summary_txts/test_physical_summary_11 @@ -10,11 +10,11 @@ Output keys : output1 ------------------------------- Constant inputs : None ------------------------------- -Static keys : None +Static keys : input ------------------------------- -Trainable keys : input +Trainable keys : None ------------------------------- -Total Parameters : >0 +Total Parameters : 0 ------------------------------- @@ -24,7 +24,7 @@ Model Name | Model Keys | ----------------------------------------------- | Keys : Shapes : Connections : Parameters ========================================================================== -Sigmoid | Inputs : input : [(V1, ...)] : 'input' : Unknown +Sigmoid | Inputs : input : [(V1, ...)] : 'input' : 0 | --------------------------------------------------------- | Outputs : output : [(V1, ...)] : 'output1' : 0 -------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_12 b/tests/scripts/summary_txts/test_physical_summary_12 index 1138666f..d591deaa 100644 --- a/tests/scripts/summary_txts/test_physical_summary_12 +++ b/tests/scripts/summary_txts/test_physical_summary_12 @@ -1,4 +1,3 @@ - Model Info ===================================== Backend type : jax @@ -11,11 +10,11 @@ Output keys : output1, output2 ------------------------------------- Constant inputs : None ------------------------------------- -Static keys : None +Static keys : input ------------------------------------- -Trainable keys : input +Trainable keys : None ------------------------------------- -Total Parameters : >0 +Total Parameters : 0 ------------------------------------- @@ -25,7 +24,7 @@ Model Name | Model Keys | ----------------------------------------------- | Keys : Shapes : Connections : Parameters ========================================================================== -Sigmoid | Inputs : input : [(V1, ...)] : 'input' : Unknown +Sigmoid | Inputs : input : [(V1, ...)] : 'input' : 0 | --------------------------------------------------------- | Outputs : output : [(V1, ...)] : 'output1' : 0 | 'output2' diff --git a/tests/scripts/summary_txts/test_physical_summary_14 b/tests/scripts/summary_txts/test_physical_summary_14 index 2327410a..9076d08c 100644 --- a/tests/scripts/summary_txts/test_physical_summary_14 +++ b/tests/scripts/summary_txts/test_physical_summary_14 @@ -10,11 +10,11 @@ Output keys : output1 -------------------------------- Constant inputs : None -------------------------------- -Static keys : None +Static keys : left, right -------------------------------- -Trainable keys : left, right +Trainable keys : None -------------------------------- -Total Parameters : >60 +Total Parameters : 0 -------------------------------- @@ -24,8 +24,8 @@ Model Name | Model Keys | ---------------------------------------------------- | Keys : Shapes : Connections : Parameters =============================================================================== -Add | Inputs : left : [3, 4, 5] : 'left' : 60 - | right : [...] : 'right' : Unknown +Add | Inputs : left : [3, 4, 5] : 'left' : 0 + | right : [...] : 'right' : 0 | -------------------------------------------------------------- | Outputs : output : [..., 3, 4, 5] : 'output1' : 0 ------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_16 b/tests/scripts/summary_txts/test_physical_summary_16 index 013c8f1f..0aba2c34 100644 --- a/tests/scripts/summary_txts/test_physical_summary_16 +++ b/tests/scripts/summary_txts/test_physical_summary_16 @@ -18,16 +18,16 @@ Total Parameters : 3 ------------------------------- - Add --------------------------------------------------------------------------------------------------- -Model Name | Model Keys - | ----------------------------------------------------------------------- - | Keys : Shapes : Types : Connections : Parameters -================================================================================================== -Add | Inputs : left : [None, ..., 3] : bool | float | int : 'output_1' : 0 - | right : [3] : bool | float | int : 'b' : 3 - | --------------------------------------------------------------------------------- - | Outputs : output : [None, ..., 3] : bool | float | int : 'output1' : 0 --------------------------------------------------------------------------------------------------- + Add +------------------------------------------------------------------------------------- +Model Name | Model Keys + | ---------------------------------------------------------- + | Keys : Shapes : Types : Connections : Parameters +===================================================================================== +Add | Inputs : left : [None, ..., 3] : float : 'output_1' : 0 + | right : [3] : float : 'b' : 3 + | -------------------------------------------------------------------- + | Outputs : output : [None, ..., 3] : float : 'output1' : 0 +------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_17 b/tests/scripts/summary_txts/test_physical_summary_17 index 9f4e8deb..78c25e8e 100644 --- a/tests/scripts/summary_txts/test_physical_summary_17 +++ b/tests/scripts/summary_txts/test_physical_summary_17 @@ -25,9 +25,9 @@ Model Name | Model Keys | Keys : Shapes : Types : Connections : Parameters ========================================================================================================== MatrixMultiply | Inputs : left : [None, ..., None] : bool | float | int : 'input' : Unknown - | right : [None, 3] : bool | float | int : 'output_0' : 0 + | right : [None, 3] : float : 'output_0' : 0 | ------------------------------------------------------------------------------------- - | Outputs : output : [None, ..., 3] : bool | float | int : 'output_1' : 0 + | Outputs : output : [None, ..., 3] : float : 'output_1' : 0 ---------------------------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_3 b/tests/scripts/summary_txts/test_physical_summary_3 index 68f52b4c..ed34f6ce 100644 --- a/tests/scripts/summary_txts/test_physical_summary_3 +++ b/tests/scripts/summary_txts/test_physical_summary_3 @@ -1,22 +1,21 @@ - Model Info -========================================================================================================================== -Backend type : jax --------------------------------------------------------------------------------------------------------------------------- -Backend precision : 32 --------------------------------------------------------------------------------------------------------------------------- -Backend device : TFRT_CPU_0 --------------------------------------------------------------------------------------------------------------------------- -Output keys : output --------------------------------------------------------------------------------------------------------------------------- -Constant inputs : None --------------------------------------------------------------------------------------------------------------------------- -Static keys : None --------------------------------------------------------------------------------------------------------------------------- -Trainable keys : bias, bias0, bias1, bias2, bias3, input1, input2, l_scale, sigma, weight, weight0, weight1, weight2, - weight3 --------------------------------------------------------------------------------------------------------------------------- -Total Parameters : >86 --------------------------------------------------------------------------------------------------------------------------- + Model Info +================================================================================================================= +Backend type : jax +----------------------------------------------------------------------------------------------------------------- +Backend precision : 32 +----------------------------------------------------------------------------------------------------------------- +Backend device : TFRT_CPU_0 +----------------------------------------------------------------------------------------------------------------- +Output keys : output +----------------------------------------------------------------------------------------------------------------- +Constant inputs : None +----------------------------------------------------------------------------------------------------------------- +Static keys : l_scale, sigma +----------------------------------------------------------------------------------------------------------------- +Trainable keys : bias, bias0, bias1, bias2, bias3, input1, input2, weight, weight0, weight1, weight2, weight3 +----------------------------------------------------------------------------------------------------------------- +Total Parameters : >84 +----------------------------------------------------------------------------------------------------------------- PhysicalModel @@ -45,7 +44,7 @@ Multiply_0 | Inputs : left : [u1, u3] : Sum.output | ------------------------------------------------------------------------------ | Outputs : output : [u1, u3] : Divide.numerator : 0 -------------------------------------------------------------------------------------------------------- -Square_1 | Inputs : input : [ 1] : 'sigma' : 1 +Square_1 | Inputs : input : [ 1] : 'sigma' : 0 | ------------------------------------------------------------------------------ | Outputs : output : [ 1] : Divide.denominator : 0 -------------------------------------------------------------------------------------------------------- @@ -58,8 +57,8 @@ Exponential | Inputs : input : [u1, u3] : Divide.output | ------------------------------------------------------------------------------ | Outputs : output : [u1, u3] : Multiply_2.right : 0 -------------------------------------------------------------------------------------------------------- -Multiply_1 | Inputs : left : [ 1] : 'l_scale' : 1 - | right : [ 1] : 'l_scale' : 1 +Multiply_1 | Inputs : left : [ 1] : 'l_scale' : 0 + | right : [ 1] : 'l_scale' : 0 | ------------------------------------------------------------------------------ | Outputs : output : [ 1] : Multiply_2.left : 0 -------------------------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth b/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth index d7540180..b5a14a3c 100644 --- a/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth +++ b/tests/scripts/summary_txts/test_physical_summary_3_logical_with_depth @@ -8,21 +8,21 @@ KernelizedSVM | Inputs : input1 : [u1, u2] : bool | float | int : '$input' | input2 : [u3, u2] : bool | float | int : '$input2' | sigma : [ 1] : bool | float | int : '$sigma' | l_scale : [ 1] : bool | float | int : '$l_scale' - | weight : [ 1, u3] : bool | float | int : '$weight' - | bias : [ 1] : bool | float | int : '$bias' + | weight : [ 1, u3] : float : '$weight' + | bias : [ 1] : float : '$bias' | ------------------------------------------------------------------------ | Outputs : kernel : [u1, u3] : float : -- | output : [u1, 1] : float : MLP.input -------------------------------------------------------------------------------------------- -MLP | Inputs : weight0 : [ 3, 1] : bool | float | int : '$weight0' +MLP | Inputs : weight0 : [ 3, 1] : float : '$weight0' | input : [u1, 1] : float : KernelizedSVM.output - | bias0 : [ 3] : bool | float | int : '$bias0' - | weight1 : [ 4, 3] : bool | float | int : '$weight1' - | bias1 : [ 4] : bool | float | int : '$bias1' - | weight2 : [ 5, 4] : bool | float | int : '$weight2' - | bias2 : [ 5] : bool | float | int : '$bias2' - | weight3 : [ 6, 5] : bool | float | int : '$weight3' - | bias3 : [ 6] : bool | float | int : '$bias3' + | bias0 : [ 3] : float : '$bias0' + | weight1 : [ 4, 3] : float : '$weight1' + | bias1 : [ 4] : float : '$bias1' + | weight2 : [ 5, 4] : float : '$weight2' + | bias2 : [ 5] : float : '$bias2' + | weight3 : [ 6, 5] : float : '$weight3' + | bias3 : [ 6] : float : '$bias3' | ------------------------------------------------------------------------ | Outputs : output : [u1, 6] : float : '$output' -------------------------------------------------------------------------------------------- @@ -43,45 +43,45 @@ RBFKernel | Inputs : input1 : [u1, u2] : bool | float | int : 'input1' | Outputs : output : [u1, u3] : bool | float | int : Linear.input | 'kernel' ------------------------------------------------------------------------------------- -Linear | Inputs : weight : [ 1, u3] : bool | float | int : 'weight' +Linear | Inputs : weight : [ 1, u3] : float : 'weight' | $axes : -- : NoneType : None | input : [u1, u3] : float : RBFKernel.output - | bias : [ 1] : bool | float | int : 'bias' + | bias : [ 1] : float : 'bias' | -------------------------------------------------------------------- - | Outputs : output : [u1, 1] : bool | float | int : 'output' + | Outputs : output : [u1, 1] : float : 'output' ------------------------------------------------------------------------------------- - MLP ---------------------------------------------------------------------------------- -Model Name | Model Keys - | ------------------------------------------------------ - | Keys : Shapes : Types : Connections -================================================================================= -Layer_0 | Inputs : weight : [ 3, 1] : bool | float | int : 'weight0' - | input : [u1, 1] : float : 'input' - | bias : [3] : bool | float | int : 'bias0' - | ---------------------------------------------------------------- - | Outputs : output : [u1, 3] : float : Layer_1.input ---------------------------------------------------------------------------------- -Layer_1 | Inputs : weight : [ 4, 3] : bool | float | int : 'weight1' - | input : [u1, 3] : float : Layer_0.output - | bias : [4] : bool | float | int : 'bias1' - | ---------------------------------------------------------------- - | Outputs : output : [u1, 4] : float : Layer_2.input ---------------------------------------------------------------------------------- -Layer_2 | Inputs : weight : [ 5, 4] : bool | float | int : 'weight2' - | input : [u1, 4] : float : Layer_1.output - | bias : [5] : bool | float | int : 'bias2' - | ---------------------------------------------------------------- - | Outputs : output : [u1, 5] : float : Layer_3.input ---------------------------------------------------------------------------------- -Layer_3 | Inputs : weight : [ 6, 5] : bool | float | int : 'weight3' - | input : [u1, 5] : float : Layer_2.output - | bias : [6] : bool | float | int : 'bias3' - | $slope : -- : float : 0.01 - | ---------------------------------------------------------------- - | Outputs : output : [u1, 6] : float : 'output' ---------------------------------------------------------------------------------- + MLP +----------------------------------------------------------------------- +Model Name | Model Keys + | -------------------------------------------- + | Keys : Shapes : Types : Connections +======================================================================= +Layer_0 | Inputs : weight : [ 3, 1] : float : 'weight0' + | input : [u1, 1] : float : 'input' + | bias : [3] : float : 'bias0' + | ------------------------------------------------------ + | Outputs : output : [u1, 3] : float : Layer_1.input +----------------------------------------------------------------------- +Layer_1 | Inputs : weight : [ 4, 3] : float : 'weight1' + | input : [u1, 3] : float : Layer_0.output + | bias : [4] : float : 'bias1' + | ------------------------------------------------------ + | Outputs : output : [u1, 4] : float : Layer_2.input +----------------------------------------------------------------------- +Layer_2 | Inputs : weight : [ 5, 4] : float : 'weight2' + | input : [u1, 4] : float : Layer_1.output + | bias : [5] : float : 'bias2' + | ------------------------------------------------------ + | Outputs : output : [u1, 5] : float : Layer_3.input +----------------------------------------------------------------------- +Layer_3 | Inputs : weight : [ 6, 5] : float : 'weight3' + | input : [u1, 5] : float : Layer_2.output + | bias : [6] : float : 'bias3' + | $slope : -- : float : 0.01 + | ------------------------------------------------------ + | Outputs : output : [u1, 6] : float : 'output' +----------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_primitive_model_summary_4 b/tests/scripts/summary_txts/test_primitive_model_summary_4 index 0efedb38..b9618aac 100644 --- a/tests/scripts/summary_txts/test_primitive_model_summary_4 +++ b/tests/scripts/summary_txts/test_primitive_model_summary_4 @@ -1,4 +1,3 @@ - Model Info ================================ Backend type : jax @@ -11,11 +10,11 @@ Output keys : None -------------------------------- Constant inputs : None -------------------------------- -Static keys : None +Static keys : left, right -------------------------------- -Trainable keys : left, right +Trainable keys : None -------------------------------- -Total Parameters : >0 +Total Parameters : 0 -------------------------------- @@ -25,8 +24,8 @@ Model Name | Model Keys | ------------------------------------------------------- | Keys : Shapes : Connections : Parameters ====================================================================================== -MatrixMultiply | Inputs : left : [(V1, ...), u1, u2] : 'left' : Unknown - | right : [(V2, ...), u2, u3] : 'right' : Unknown +MatrixMultiply | Inputs : left : [(V1, ...), u1, u2] : 'left' : 0 + | right : [(V2, ...), u2, u3] : 'right' : 0 | ----------------------------------------------------------------- | Outputs : output : [(V3, ...), u1, u3] : 'output_0' : 0 -------------------------------------------------------------------------------------- diff --git a/tests/scripts/summary_txts/test_primitive_model_summary_5 b/tests/scripts/summary_txts/test_primitive_model_summary_5 index 01efc4dd..1c14a68b 100644 --- a/tests/scripts/summary_txts/test_primitive_model_summary_5 +++ b/tests/scripts/summary_txts/test_primitive_model_summary_5 @@ -1,4 +1,3 @@ - Model Info =============================== Backend type : jax @@ -11,11 +10,11 @@ Output keys : None ------------------------------- Constant inputs : None ------------------------------- -Static keys : None +Static keys : in1 ------------------------------- -Trainable keys : in1 +Trainable keys : None ------------------------------- -Total Parameters : >0 +Total Parameters : 0 ------------------------------- @@ -25,7 +24,7 @@ Model Name | Model Keys | -------------------------------------------------------------- | Keys : Shapes : Connections : Parameters ========================================================================================= -Add | Inputs : left : [(V1, ...)] : 'in1' : Unknown +Add | Inputs : left : [(V1, ...)] : 'in1' : 0 | right : [(V2, ...), u1, u2] : 'output_0' : 0 | ------------------------------------------------------------------------ | Outputs : output : [(V3, ...), u3, u4] : 'output_1' : 0 diff --git a/tests/scripts/summary_txts/test_primitive_model_summary_8 b/tests/scripts/summary_txts/test_primitive_model_summary_8 index 7638a0b1..a475eb4c 100644 --- a/tests/scripts/summary_txts/test_primitive_model_summary_8 +++ b/tests/scripts/summary_txts/test_primitive_model_summary_8 @@ -1,4 +1,3 @@ - Model Info ===================================== Backend type : jax @@ -11,11 +10,11 @@ Output keys : output ------------------------------------- Constant inputs : slope ------------------------------------- -Static keys : None +Static keys : in1, left, right ------------------------------------- -Trainable keys : in1, left, right +Trainable keys : None ------------------------------------- -Total Parameters : >0 +Total Parameters : 0 ------------------------------------- @@ -25,12 +24,12 @@ Model Name | Model Keys | ------------------------------------------------------------------------ | Keys : Shapes : Connections : Parameters ======================================================================================================= -MatrixMultiply | Inputs : left : [(V1, ...), u1, u2] : 'left' : Unknown - | right : [(V2, ...), u2, u3] : 'right' : Unknown +MatrixMultiply | Inputs : left : [(V1, ...), u1, u2] : 'left' : 0 + | right : [(V2, ...), u2, u3] : 'right' : 0 | ---------------------------------------------------------------------------------- | Outputs : output : [(V3, ...), u1, u3] : Add.right : 0 ------------------------------------------------------------------------------------------------------- -Add | Inputs : left : [(V4, ...)] : 'in1' : Unknown +Add | Inputs : left : [(V4, ...)] : 'in1' : 0 | right : [(V3, ...), u1, u3] : MatrixMultiply.output : 0 | ---------------------------------------------------------------------------------- | Outputs : output : [(V5, ...), u4, u5] : Sigmoid.input : 0 diff --git a/tests/scripts/summary_txts/test_traincontext_summary_3 b/tests/scripts/summary_txts/test_traincontext_summary_3 index f406b795..3ffa7e39 100644 --- a/tests/scripts/summary_txts/test_traincontext_summary_3 +++ b/tests/scripts/summary_txts/test_traincontext_summary_3 @@ -1,14 +1,13 @@ - Model ----------------------------------------------------------------------------------------------------- Model Name | Model Keys | ---------------------------------------------------------------------- | Keys : Shapes : Types : Connections ===================================================================================================== -Add_0 | Inputs : left : [(V1, ...)] : bool | float | int : 'in1' +Add_0 | Inputs : left : [(V1, ...)] : float : 'in1' | right : [(V2, ...)] : bool | float | int : 'in2' | -------------------------------------------------------------------------------- - | Outputs : output : [(V3, ...)] : bool | float | int : 'output1' + | Outputs : output : [(V3, ...)] : float : 'output1' ----------------------------------------------------------------------------------------------------- Add_1 | Inputs : left : [(V4, ...)] : bool | float | int : '$left' | right : [(V5, ...)] : bool | float | int : '$right_0' diff --git a/tests/scripts/summary_txts/test_traincontext_summary_4 b/tests/scripts/summary_txts/test_traincontext_summary_4 index e19c72b9..8319f151 100644 --- a/tests/scripts/summary_txts/test_traincontext_summary_4 +++ b/tests/scripts/summary_txts/test_traincontext_summary_4 @@ -1,14 +1,13 @@ - Model ------------------------------------------------------------------------ Model Name | Model Keys | ----------------------------------------- | Keys : Types : Connections ======================================================================== -Add_0 | Inputs : left : bool | float | int : 'in1' - | right : bool | float | int : 'in2' +Add_0 | Inputs : left : float : 'in1' + | right : float : 'in2' | --------------------------------------------------- - | Outputs : output : bool | float | int : 'output1' + | Outputs : output : float : 'output1' ------------------------------------------------------------------------ Add_1 | Inputs : left : bool | float | int : '$left' | right : bool | float | int : '$right_0' diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index 372a7909..676e93ce 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -140,7 +140,10 @@ def compile_and_compare( } pm = mithril.compile( - model, backend=backend, **compile_kwargs | {"constant_keys": statics} + model, + backend=backend, + **compile_kwargs + | {"constant_keys": statics, "trainable_keys": params.keys()}, ) outputs = pm.evaluate(params=backend_params, data=backend_data) @@ -412,7 +415,7 @@ def test_nan_to_num_1(): def test_linear_1(): model = Linear() - model.input.set_differentiable(True) + model.set_differentiability(input=True) params = {"input": [[1.0], [2.0], [3.0], [4.0]], "weight": [[0.2]], "bias": [0.5]} output_gradients = {"output": [[1.0], [1.0], [1.0], [1.0]]} reference_outputs = {"output": [[0.7], [0.9], [1.1], [1.3]]} diff --git a/tests/scripts/test_c_backend.py b/tests/scripts/test_c_backend.py index 6ef3d790..4bf6fc2c 100644 --- a/tests/scripts/test_c_backend.py +++ b/tests/scripts/test_c_backend.py @@ -38,12 +38,14 @@ def test_cbackend_1(): model, c_backend, shapes={"left": [5, 5], "right": [5, 5]}, + trainable_keys={"left", "right"}, jit=False, ) np_pm = compile( model, np_backend, shapes={"left": [5, 5], "right": [5, 5]}, + trainable_keys={"left", "right"}, jit=False, ) @@ -95,12 +97,14 @@ def test_cbackend_2(file_path: str): c_backend, file_path=file_path, shapes={"left": [5, 5], "left2": [5, 5], "right": [5, 5]}, + trainable_keys={"left", "left2", "right"}, jit=False, ) np_pm = compile( model, np_backend, shapes={"left": [5, 5], "right": [5, 5]}, + trainable_keys={"left", "left2", "right"}, jit=False, ) @@ -160,12 +164,14 @@ def test_cbackend_3(): model, c_backend, shapes={"left": [5, 5], "mul": [5, 5], "right": [5, 5]}, + trainable_keys={"left", "mul", "right"}, jit=False, ) np_pm = compile( model, np_backend, shapes={"left": [5, 5], "right": [5, 5]}, + trainable_keys={"left", "mul", "right"}, jit=False, ) diff --git a/tests/scripts/test_canonicality.py b/tests/scripts/test_canonicality.py index 63e1d151..3a8e70ad 100644 --- a/tests/scripts/test_canonicality.py +++ b/tests/scripts/test_canonicality.py @@ -918,9 +918,9 @@ def test_existing_connection_parent_output_updated_to_internal(): def test_compile_multi_canonical_output_no_exposed_output(): model = Model() - model |= Relu() - model |= Relu() - model |= Relu() + model |= Relu()("input1") + model |= Relu()("input2") + model |= Relu()("input3") backend = ml.JaxBackend() pm = ml.compile(model, backend) diff --git a/tests/scripts/test_codegen.py b/tests/scripts/test_codegen.py index ce05fefb..adb62d3d 100644 --- a/tests/scripts/test_codegen.py +++ b/tests/scripts/test_codegen.py @@ -18,6 +18,7 @@ import mithril from mithril import JaxBackend, MlxBackend, NumpyBackend, TorchBackend +from mithril.framework.logical.model import IOKey from mithril.models import ( Arange, Concat, @@ -47,7 +48,7 @@ def list_full(fill_value, *shapes): @with_temp_file(".py") def test_single_input_primitive(file_path): model = Model() - model += Relu()(input="input", output="output") + model += Relu()(input=IOKey("input", differantiable=True), output="output") model.set_shapes({"input": [1, 2, 3]}) backend = NumpyBackend() @@ -103,8 +104,10 @@ def evaluate(params, data, cache): @with_temp_file(".py") def test_multi_input_primitive(file_path: str): model = Model() - model += Linear()(input="input", weight="w", bias="b", output="output") - model.input.set_differentiable(True) # type: ignore + model += Linear()( + input=IOKey("input", differantiable=True), weight="w", bias="b", output="output" + ) + model.set_differentiability(input=True) model.set_shapes({"input": [1, 2, 3]}) backend = NumpyBackend() @@ -123,11 +126,11 @@ def evaluate(params, data, cache): output_cache = cache["output_cache"] w = params["w"] output_0 = output_0_cache["output"] = transpose(w, None, cache=output_0_cache) - output_1 = output_1_cache["output"] = make_array( - matrix_multiplication(input, output_0, output_1_cache) + output_1 = output_1_cache["output"] = matrix_multiplication( + input, output_0, output_1_cache ) del output_0 - output = output_cache["output"] = make_array(add(output_1, b, output_cache)) + output = output_cache["output"] = add(output_1, b, output_cache) del output_1 return {"output": output} @@ -141,9 +144,9 @@ def evaluate(params, data, cache): input = params["input"] w = params["w"] output_0 = transpose(w, None) - output_1 = make_array(matrix_multiplication(input, output_0)) + output_1 = matrix_multiplication(input, output_0) del output_0 - output = make_array(add(output_1, b)) + output = add(output_1, b) del output_1 return {"output": output} @@ -187,7 +190,12 @@ def evaluate(params, data, cache): @with_temp_file(".py") def test_variadic_input_primitive_1(file_path: str): model = Model() - model += Concat(n=3)(input1="input1", input2="input2", output="output") + model += Concat(n=3)( + input1=IOKey("input1", differantiable=True), + input2=IOKey("input2", differantiable=True), + input3=IOKey("input3", differantiable=True), + output="output", + ) model.set_shapes({"input1": [1, 2, 3]}) backend = NumpyBackend() @@ -202,8 +210,8 @@ def evaluate(params, data, cache): input2 = params["input2"] input3 = params["input3"] output_cache = cache["output_cache"] - output = output_cache["output"] = make_array( - concat(input1, input2, input3, cache=output_cache) + output = output_cache["output"] = concat( + input1, input2, input3, cache=output_cache ) return {"output": output} @@ -216,7 +224,7 @@ def evaluate(params, data, cache): input1 = params["input1"] input2 = params["input2"] input3 = params["input3"] - output = make_array(concat(input1, input2, input3)) + output = concat(input1, input2, input3) return {"output": output} compare_callables(evaluate, eval_func) @@ -313,7 +321,7 @@ def evaluate(params, data, cache): @with_temp_file(".py") def test_default_kwarg_reduction_1(file_path: str): model = Model() - model += Mean() + model += Mean()(input=IOKey("input", differantiable=True)) backend = NumpyBackend() mithril.compile(model, backend, inference=False, jit=False, file_path=file_path) @@ -374,7 +382,7 @@ def evaluate(params, data, cache): @with_temp_file(".py") def test_default_kwarg_reduction_2(file_path: str): model = Model() - model += Mean(axis=3)() + model += Mean(axis=3)(input=IOKey("input", differantiable=True)) backend = NumpyBackend() diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index edfdc987..cf8eee8e 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -474,10 +474,10 @@ def test_axis(): model = Model() relu = LeakyRelu() rob_pow = Power(robust=True) - model += relu(input="input", slope=Tensor(2.3)) + model += relu(input=IOKey("input", differantiable=True), slope=Tensor(2.3)) model += rob_pow( base=relu.output, - exponent=IOKey("exponent", type=Tensor), + exponent=IOKey("exponent", type=Tensor, differantiable=True), threshold=relu.slope, ) @@ -509,7 +509,11 @@ def test_axis_1(): relu = LeakyRelu() rob_pow = Power(robust=True) rob_pow.set_types(base=Tensor, exponent=Tensor) - model += rob_pow(base="base", threshold=Tensor(2.3), exponent="exponent") + model += rob_pow( + base=IOKey("base", differantiable=True), + threshold=Tensor(2.3), + exponent=IOKey("exponent", differantiable=True), + ) model += relu(input=rob_pow.output, slope=rob_pow.threshold) # type: ignore # Check required value transfer occured in logical model # assert relu.conns.get_data("slope").value == 2.3 @@ -800,7 +804,7 @@ def test_static_2(): add_1 = Add() model1 += add_1( left=Tensor([2.0, 3.0]), - right=IOKey("right", type=Tensor), + right=IOKey("right", type=Tensor, differantiable=True), output=IOKey(name="output"), ) model2 += model1 @@ -825,7 +829,10 @@ def test_static_2_set_values(): model1 = Model() model2 = Model() add_1 = Add() - model1 += add_1(right=IOKey("right", type=Tensor), output=IOKey(name="output")) + model1 += add_1( + right=IOKey("right", type=Tensor, differantiable=True), + output=IOKey(name="output"), + ) model1.set_values({add_1.left: Tensor([2.0, 3.0])}) model2 += model1 comp_model = ml.compile(model=model2, backend=NumpyBackend()) @@ -1180,8 +1187,7 @@ def test_static_input_1(): model = Model() add_1 = Add() add_1.set_types(left=Tensor, right=Tensor) - add_1.left.set_differentiable(False) - add_1.right.set_differentiable(False) + ref = np.array(5.0) model += add_1 comp_model = ml.compile( @@ -1203,8 +1209,7 @@ def test_static_input_1_safe_names(): model = Model() add_1 = Add() add_1.set_types(left=Tensor, right=Tensor) - add_1.left.set_differentiable(False) - add_1.right.set_differentiable(False) + model += add_1 with pytest.raises(KeyError) as err: ml.compile(model=model, backend=NumpyBackend(), jit=False) @@ -1219,8 +1224,7 @@ def test_static_input_2(): add_1 = Add() add_1.set_types(left=Tensor, right=Tensor) ref = np.array(5.0) - add_1.left.set_differentiable(False) - add_1.right.set_differentiable(False) + model += add_1() comp_model = ml.compile( model=model, @@ -1244,8 +1248,7 @@ def test_static_input_2_safe_names(): model = Model() add_1 = Add() add_1.set_types(left=Tensor, right=Tensor) - add_1.left.set_differentiable(False) - add_1.right.set_differentiable(False) + model += add_1() with pytest.raises(KeyError) as err: ml.compile( @@ -1264,8 +1267,7 @@ def test_static_input_3(): add_1 = Add() add_1.set_types(left=Tensor, right=Tensor) ref = np.array(5.0) - add_1.left.set_differentiable(False) - add_1.right.set_differentiable(False) + model += add_1() comp_model = ml.compile( model=model, @@ -1308,8 +1310,7 @@ def test_static_input_5(): add_1 = Add() add_1.set_types(left=Tensor, right=Tensor) ref = np.array(5.0) - add_1.left.set_differentiable(False) - add_1.right.set_differentiable(False) + model += add_1(left="input", right="right") comp_model = ml.compile( model=model, @@ -1427,7 +1428,7 @@ def test_static_input_7(): def test_linear_1(): model = Model() lin1 = Linear() - lin1.input.set_differentiable(True) + lin1.set_differentiability(input=True) lin1.set_shapes({"weight": [2, 2], "input": [2, 2]}) model += lin1(input="input", output=IOKey(name="output")) assert_all_backends_device_dtype(model) @@ -1437,7 +1438,7 @@ def test_mlp(): mlp_model = MLP( activations=[Buffer(), LeakyRelu(), Sigmoid()], dimensions=[2, 1, 1] ) - mlp_model.input.set_differentiable(True) + mlp_model.set_differentiability(input=True) mlp_model.set_shapes({"input": [1, 1]}) assert_all_backends_device_dtype(mlp_model) @@ -1447,7 +1448,7 @@ def test_add_1(): add_model = Add() model += add_model( left=Tensor(1), - right=IOKey("right", type=Tensor), + right=IOKey("right", type=Tensor, differantiable=True), output=IOKey(name="output"), ) model.set_shapes({"right": [1, 1, 1]}) @@ -1460,7 +1461,9 @@ def test_composite_1(): shape_model = Shape() index_model = Indexer() red_model = Mean(axis=TBD) - model += add_model(left=Tensor([[[1]]]), right=IOKey("right", type=Tensor)) + model += add_model( + left=Tensor([[[1]]]), right=IOKey("right", type=Tensor, differantiable=True) + ) model += shape_model(input=add_model.output) model += index_model(input=shape_model.output, index=1) model += red_model( @@ -1477,7 +1480,7 @@ def test_composite_1_set_values(): shape_model = Shape() index_model = Indexer() red_model = Mean(axis=TBD) - model += add_model(right=IOKey("right", type=Tensor)) + model += add_model(right=IOKey("right", type=Tensor, differantiable=True)) model.set_values({add_model.left: Tensor([[[1]]])}) model += shape_model(input=add_model.output) model += index_model(input=shape_model.output, index=1) @@ -1497,8 +1500,9 @@ def test_composite_2(): model = Model() conv1 = Convolution2D(kernel_size=2, out_channels=4) leaky_relu = LeakyRelu() - model += conv1(input="input") - conv1.input.set_differentiable(True) + model += conv1(input=IOKey("input", differantiable=True)) + + conv1.set_differentiability(input=True) model += leaky_relu( input=conv1.output, output=IOKey(name="output"), slope=Tensor(0.3) ) @@ -1511,7 +1515,7 @@ def test_composite_2_set_values(): conv1 = Convolution2D(kernel_size=2, out_channels=4) leaky_relu = LeakyRelu() model += conv1(input="input") - conv1.input.set_differentiable(True) + conv1.set_differentiability(input=True) model += leaky_relu( input=conv1.output, output=IOKey(name="output"), slope=NOT_GIVEN ) @@ -1526,7 +1530,7 @@ def test_composite_3(): leaky_relu = LeakyRelu() mean_model = Mean(axis=TBD) model += conv1(input="input", stride=(2, 3)) - conv1.input.set_differentiable(True) + conv1.set_differentiability(input=True) model += leaky_relu(input=conv1.output, slope=Tensor(0.3)) model += mean_model(axis=conv1.stride) # assert not isinstance(conv1.cout, NotAvailable) @@ -1541,7 +1545,7 @@ def test_composite_3_set_values(): leaky_relu = LeakyRelu() mean_model = Mean(axis=TBD) model += conv1(input="input") - conv1.input.set_differentiable(True) + conv1.set_differentiability(input=True) model.set_values({conv1.stride: (2, 3)}) model += leaky_relu(input=conv1.output, slope=NOT_GIVEN) model.set_values({leaky_relu.slope: Tensor(0.3)}) @@ -1559,7 +1563,7 @@ def test_composite_4(): leaky_relu = LeakyRelu() mean_model = Mean(axis=TBD) model += conv1(input="input", stride=(2, 3)) - conv1.input.set_differentiable(True) + conv1.set_differentiability(input=True) model += leaky_relu(input=conv1.output, slope=Tensor(0.3)) model += mean_model(axis=conv1.stride) model.set_shapes({"input": [1, 1, 8, 8]}) @@ -1573,7 +1577,7 @@ def test_composite_4_set_values(): leaky_relu = LeakyRelu() mean_model = Mean(axis=TBD) model += conv1(input="input") - conv1.input.set_differentiable(True) + conv1.set_differentiability(input=True) model.set_values({conv1.stride: (2, 3)}) model += leaky_relu(input=conv1.output, slope=NOT_GIVEN) model.set_values({leaky_relu.slope: Tensor(0.3)}) @@ -1786,7 +1790,7 @@ def test_unused_cached_values_1_set_values(): } model.set_values(config) comp_model = ml.compile( - model=model, backend=(backend := NumpyBackend()), inference=True + model=model, backend=(backend := NumpyBackend()), inference=True, jit=False ) dtype = backend.get_backend_array_type() cache = comp_model.flat_graph.data_store.data_values @@ -1893,7 +1897,7 @@ def test_unused_cached_values_3(): model = Model() linear_model = Linear(dimension=2) model += linear_model(input=Tensor([[3.0], [2.0]]), weight=Tensor([[1.0], [2.0]])) - linear_model.bias.set_differentiable(False) + linear_model.set_differentiability(bias=False) comp_model = ml.compile( model=model, backend=(backend := NumpyBackend()), safe_names=False ) @@ -1936,7 +1940,7 @@ def test_unused_cached_values_3_set_values(): linear_model.weight: Tensor([[1.0], [2.0]]), } ) - linear_model.bias.set_differentiable(False) + linear_model.set_differentiability(bias=False) comp_model = ml.compile( model=model, backend=(backend := NumpyBackend()), safe_names=False ) @@ -2118,11 +2122,13 @@ def test_nontensor_gradient(): relu = Relu() add_model = Add() - model += shape_model(input="input") + model += shape_model(input=IOKey("input", differantiable=True)) model += relu(input="input") model += to_tensor_model(input=shape_model.output, output=IOKey(name="out1")) model += add_model( - left=IOKey("in1", type=Tensor), right=relu.output, output=IOKey(name="out2") + left=IOKey("in1", type=Tensor, differantiable=True), + right=relu.output, + output=IOKey(name="out2"), ) ctx = TrainModel(model) @@ -2164,9 +2170,11 @@ def test_nontensor_gradient_2(): left="", right=to_tensor_model.output, output=IOKey(name="output") ) model += mult_model( - left="", right=IOKey("right1", type=Tensor), output=add_model.left + left="", + right=IOKey("right1", type=Tensor, differantiable=True), + output=add_model.left, ) - model += relu_model(input="in1", output=mult_model.left) + model += relu_model(input=IOKey("in1", differantiable=True), output=mult_model.left) constant_keys = { "input": backend.array([[10.0, 2.0], [1.0, 1.0]]), } @@ -2194,7 +2202,7 @@ def test_nontensor_gradient_3(): model = Model() shape_model = Shape() to_tensor_model = ToTensor() - model += shape_model(input="input") + model += shape_model(input=IOKey("input", differantiable=True)) model += to_tensor_model(input=shape_model.output, output=IOKey(name="output")) ctx = TrainModel(model) ctx.add_loss(Buffer(), input="output", reduce_steps=[Sum()]) @@ -2218,8 +2226,8 @@ def test_numpy_without_shape(): model = Model() add_model = Add() model += add_model( - left=IOKey("left", type=Tensor), - right=IOKey("right", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), + right=IOKey("right", type=Tensor, differantiable=True), output=IOKey(name="output"), ) model.set_shapes({"left": [], "right": []}) @@ -2248,14 +2256,14 @@ def test_multiple_to_tensor(): model += tt_1 model += add_model( left=model.cout, - right=IOKey("right", type=Tensor), + right=IOKey("right", type=Tensor, differantiable=True), output=IOKey(name="output"), ) model_1 += shp_2 model_1 += tt_2 model_1 += add_model_2( left=model_1.cout, - right=IOKey("right", type=Tensor), + right=IOKey("right", type=Tensor, differantiable=True), output=IOKey(name="output"), ) model_2 += model(input="input") @@ -2275,7 +2283,10 @@ def test_concat_axis_ellipsis_1(): backend = NumpyBackend() model = Model() concat_model = Concat(n=2, axis=TBD) - model += concat_model(input1="input1", input2="input2") + model += concat_model( + input1=IOKey("input1", differantiable=True), + input2=IOKey("input2", differantiable=True), + ) comp_model = ml.compile(model=model, backend=backend, safe_names=False) in1 = backend.array([[2.0]]) @@ -2294,7 +2305,11 @@ def test_concat_axis_ellipsis_2(): backend = NumpyBackend() model = Model() concat_model = Concat(n=2, axis=TBD) - model += concat_model(input1="input1", input2="input2", axis="axis") + model += concat_model( + input1=IOKey("input1", differantiable=True), + input2=IOKey("input2", differantiable=True), + axis="axis", + ) comp_model = ml.compile(model=model, backend=backend) in1 = backend.array([[2.0]]) @@ -2315,7 +2330,9 @@ def test_polyfeatures_degree_ellipsis(): model = Model() poly_feat_model = PolynomialFeatures(degree=TBD) model += poly_feat_model( - input="input", output=IOKey(name="output"), degree="degree" + input=IOKey("input", differantiable=True), + output=IOKey(name="output"), + degree="degree", ) comp_model = ml.compile(model=model, backend=backend) diff --git a/tests/scripts/test_data_store.py b/tests/scripts/test_data_store.py index a5a6f90f..470692f3 100644 --- a/tests/scripts/test_data_store.py +++ b/tests/scripts/test_data_store.py @@ -150,6 +150,7 @@ def test_data_store_4(): "bias", "linear_transpose_axes", } + print(pm.flat_graph.data_store.unused_keys) assert pm.flat_graph.data_store.unused_keys == ref_unused_keys @@ -283,8 +284,10 @@ def test_data_store_11(): ) assert pm.flat_graph.data_store.data_values.keys() == {"out"} - assert pm.flat_graph.data_store.runtime_static_keys == set() - assert pm.flat_graph.data_store.intermediate_non_differentiables._table == dict() + assert pm.flat_graph.data_store.runtime_static_keys == {"something"} + assert pm.flat_graph.data_store.intermediate_non_differentiables._table == { + "out2": pm.flat_graph.data_store.all_data["out2"] + } assert pm.flat_graph.data_store.unused_keys == {"left", "right"} infered_value = pm.flat_graph.data_store.data_values["out"] @@ -444,7 +447,7 @@ def test_data_store_15(): add = Add() add.set_types(left=Tensor, right=Tensor) model += add(left="left") - add.right.set_differentiable(False) + model += Sigmoid()(input=add.output, output="output") pm = PhysicalModel( model=model, @@ -464,10 +467,11 @@ def test_data_store_15(): "output_0_cache", "output_cache", } - assert pm.flat_graph.data_store.runtime_static_keys == {"right"} - assert ( - pm.flat_graph.data_store.intermediate_non_differentiables._table.keys() == set() - ) + assert pm.flat_graph.data_store.runtime_static_keys == {"left", "right"} + assert pm.flat_graph.data_store.intermediate_non_differentiables._table.keys() == { + "output_0", + "output", + } assert pm.flat_graph.data_store.unused_keys == set() @@ -478,7 +482,7 @@ def test_data_store_16(): add = Add() add.set_types(left=Tensor, right=Tensor) model += add(left="left") - add.right.set_differentiable(False) + model += Sigmoid()(input=add.output, output=IOKey("output")) model += Relu()(input="in", output=IOKey(name="out")) @@ -498,14 +502,14 @@ def test_data_store_16(): ) assert pm.flat_graph.data_store.data_values.keys() == set() - assert pm.flat_graph.data_store.runtime_static_keys == set() - assert ( - pm.flat_graph.data_store.intermediate_non_differentiables._table.keys() == set() - ) + assert pm.flat_graph.data_store.runtime_static_keys == {"in"} + assert pm.flat_graph.data_store.intermediate_non_differentiables._table.keys() == { + "out" + } assert pm.flat_graph.data_store.unused_keys == { "right", - "left", "output_0", + "left", "output", } @@ -535,15 +539,15 @@ def test_data_store_17(): ) assert pm.flat_graph.data_store.data_values.keys() == set() - assert pm.flat_graph.data_store.runtime_static_keys == set() - assert ( - pm.flat_graph.data_store.intermediate_non_differentiables._table.keys() == set() - ) + assert pm.flat_graph.data_store.runtime_static_keys == {"in"} + assert pm.flat_graph.data_store.intermediate_non_differentiables._table.keys() == { + "out" + } assert pm.flat_graph.data_store.unused_keys == { - "output", + "right", "output_0", "left", - "right", + "output", } diff --git a/tests/scripts/test_differentiablity.py b/tests/scripts/test_differentiablity.py index bd04b48c..07f76e4c 100644 --- a/tests/scripts/test_differentiablity.py +++ b/tests/scripts/test_differentiablity.py @@ -15,30 +15,51 @@ import mithril from mithril import JaxBackend from mithril.framework.common import Tensor -from mithril.models import Add, Buffer, IOKey, Linear, Model, Multiply +from mithril.models import ( + Add, + Buffer, + Equal, + FloorDivide, + Greater, + GreaterEqual, + IOKey, + Less, + LessEqual, + Linear, + Model, + Multiply, + NotEqual, +) + + +def test_buffer(): + model = Model() + buffer = Buffer() + model += buffer(input=IOKey("input", differantiable=True)) + assert model.input.metadata.differentiable # type: ignore -def test_data_linear(): +def test_linear(): model = Linear() - assert model.input.metadata.is_non_diff + assert not model.input.metadata.differentiable -def test_data_linear_compile(): +def test_linear_compile(): model = Model() - model += Linear()(input="input") + model += Linear()(input="input", weight="weight", bias="bias") backend = JaxBackend() pm = mithril.compile(model, backend) assert "input" in pm.flat_graph.runtime_static_keys -def test_convert_input_data_to_trainable(): +def test_input_data_to_trainable(): model = Model() model += Linear()(input="input") model += Linear()(weight=model.input) # type: ignore assert model.input.metadata.differentiable # type: ignore -def test_convert_input_data_to_trainable_compile(): +def test_input_data_to_trainable_compile(): model = Model() model += Linear()(input="input") model += Linear()(weight=model.input) # type: ignore @@ -51,34 +72,39 @@ def test_convert_input_data_to_trainable_compile(): ) -def test_convert_internal_data_to_trainable(): +def test_internal_data_to_trainable(): model = Model() model += Linear()(input="internal_key") model += Linear()(input="input", output=model.internal_key) # type: ignore - assert model.internal_key.metadata.differentiable # type: ignore + + pm = mithril.compile(model, JaxBackend(), jit=False, use_short_namings=False) + assert pm.data["linear_0_matrixmultiply_output"].differentiable # type: ignore + assert pm.data["linear_1_matrixmultiply_output"].differentiable # type: ignore -def test_set_values_data_and_param(): +def test_set_diff_data_and_param(): model = Multiply() model.set_types(left=Tensor, right=Tensor) - model.left.set_differentiable(False) - assert model.left.metadata.is_non_diff - model.left.set_differentiable(True) - assert not model.left.metadata.is_non_diff - model.left.set_differentiable(False) - assert model.left.metadata.is_non_diff + model.set_differentiability(left=False) + assert not model.left.metadata.differentiable + model.set_differentiability(left=True) + assert model.left.metadata.differentiable + model.set_differentiability(left=False) + assert not model.left.metadata.differentiable def test_match_tensor_with_value_data_and_param(): model1 = Multiply() model1.set_types(left=Tensor) - model1.left.set_differentiable(False) - assert model1.left.metadata.is_non_diff + model1.set_differentiability(left=False) + + assert not model1.left.metadata.differentiable model2 = Multiply() model2.set_types(left=Tensor) - model2.left.set_differentiable(True) - assert not model2.left.metadata.is_non_diff + model2.set_differentiability(left=True) + + assert model2.left.metadata.differentiable model = Model() model += model1(left="my_input") @@ -89,29 +115,31 @@ def test_match_tensor_with_value_data_and_param(): def test_match_tensor_with_value_data_and_param_rev(): model2 = Multiply() model2.set_types(left=Tensor) - model2.left.set_differentiable(True) - assert not model2.left.metadata.is_non_diff + model2.set_differentiability(left=True) + + assert model2.left.metadata.differentiable model1 = Multiply() model1.set_types(left=Tensor) - model1.left.set_differentiable(False) - assert model1.left.metadata.is_non_diff + model1.set_differentiability(left=False) + + assert not model1.left.metadata.differentiable model = Model() model += model1(left="my_input") model += model2(left="my_input") - assert not model.my_input.metadata.is_non_diff # type: ignore + assert model.my_input.metadata.differentiable # type: ignore -def test_non_trainability_flow_in_compile(): +def test_diff_inference(): model = Model() buff_model = Buffer() buff_model.set_types(input=Tensor) - buff_model.input.set_differentiable(False) + buff_model.set_differentiability(input=False) model += buff_model(input="input") mult = Multiply() mult.set_types(left=Tensor, right=Tensor) - mult.left.set_differentiable(False) + mult.set_differentiability(left=False) model += mult(left="left", right=model.cout, output="output") backend = JaxBackend() @@ -119,12 +147,16 @@ def test_non_trainability_flow_in_compile(): assert not pm.flat_graph.all_data["output"].differentiable -def test_non_trainability_flow_in_compile_with_data_keys_1(): +def test_diff_inference_constant_key_to_differentiable_input(): model = Model() buff_model = Buffer() model += buff_model(input="input") mult = Multiply() - model += mult(left=IOKey("left", type=Tensor), right=model.cout, output="output") + model += mult( + left=IOKey("left", type=Tensor, differantiable=True), + right=model.cout, + output="output", + ) backend = JaxBackend() pm = mithril.compile( @@ -133,30 +165,34 @@ def test_non_trainability_flow_in_compile_with_data_keys_1(): assert not pm.flat_graph.all_data["output"].differentiable -def test_non_trainability_flow_in_compile_with_data_keys_2(): +def test_diff_inference_data_key_to_differentiable_input(): model = Model() buff_model = Buffer() model += buff_model(input="input") mult = Multiply() - model += mult(left=IOKey("left", type=Tensor), right=model.cout, output="output") + model += mult( + left=IOKey("left", type=Tensor, differantiable=True), + right=model.cout, + output="output", + ) backend = JaxBackend() - pm = mithril.compile(model, backend, data_keys={"input"}) - assert pm.flat_graph.all_data["output"].differentiable + pm = mithril.compile(model, backend, data_keys={"input", "left"}) + assert not pm.flat_graph.all_data["output"].differentiable -def test_non_trainability_flow_in_compile_with_data_keys_3(): +def test_diff_inference_with_data_keys_3(): model = Model() buff_model = Buffer() model += buff_model(input="input", output="buff_out") mult = Multiply() model += mult( - left=IOKey("left", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), right=model.cout, output=IOKey("mult_out"), ) model += Add()( - left=IOKey("left", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), right=buff_model.output, output=IOKey("add_out"), ) @@ -167,11 +203,12 @@ def test_non_trainability_flow_in_compile_with_data_keys_3(): assert pm.flat_graph.all_data["add_out"].differentiable -def test_trainability_flow_in_compile_with_trainable_keys(): +def test_diff_inference_with_trainable_keys(): model = Model() buff_model = Buffer() buff_model.set_types(input=Tensor) - buff_model.input.set_differentiable(False) + buff_model.set_differentiability(input=False) + model += buff_model(input="input", output="buff_out") mult = Multiply() model += mult( @@ -180,9 +217,74 @@ def test_trainability_flow_in_compile_with_trainable_keys(): output=IOKey("mult_out"), ) model += Add()(left="left", right=buff_model.output, output=IOKey("add_out")) - model.left.set_differentiable(False) # type: ignore + model.set_differentiability(left=False) backend = JaxBackend() pm = mithril.compile(model, backend, trainable_keys={"input"}) assert pm.flat_graph.all_data["mult_out"].differentiable assert pm.flat_graph.all_data["add_out"].differentiable + + +def test_diff_inference_floor_div(): + model = Model() + model += FloorDivide()("input", "denom", "output") + + pm = mithril.compile(model, JaxBackend(), inference=True) + + assert not pm.flat_graph.all_data["output"].differentiable + + +def test_diff_inference_relational_ops(): + primitives = [Greater, Less, GreaterEqual, LessEqual, Equal, NotEqual] + + for primitive in primitives: + model = Model() + model += primitive()("input", "denom", "output") + + pm = mithril.compile(model, JaxBackend(), inference=True) + + assert not pm.flat_graph.all_data["output"].differentiable + + +def test_diff_inference_constant_keys_1(): + model = Model() + model += Multiply()(IOKey("input", differantiable=True), "denom", "output") + + pm = mithril.compile(model, JaxBackend(), constant_keys={"denom": 1.0}) + + assert pm.flat_graph.all_data["output"].differentiable + assert not pm.flat_graph.all_data["denom"].differentiable + + +def test_diff_inference_constant_keys_2(): + model = Model() + model += Multiply()(IOKey("input", differantiable=True), "denom", "output") + + backend = JaxBackend() + + pm = mithril.compile( # type: ignore + model, + backend, + constant_keys={"input": backend.ones((4, 4)), "denom": 1.0}, + inference=True, + ) + + assert not pm.flat_graph.all_data["output"].differentiable + # assert not pm.flat_graph.all_data["denom"].differentiable + # assert not pm.flat_graph.all_data["input"].differentiable + + +def test_diff_inference_add(): + model = Model() + model += Add()("left", "right", "output") + assert not model.left.metadata.differentiable # type: ignore + assert not model.right.metadata.differentiable # type: ignore + assert not model.output.metadata.differentiable # type: ignore + + model.set_differentiability(left=False) + + assert not model.left.metadata.differentiable # type: ignore + + model.set_types(left=Tensor) + + assert not model.left.metadata.differentiable # type: ignore diff --git a/tests/scripts/test_extend_template.py b/tests/scripts/test_extend_template.py index 9946a8ff..fa4ea720 100644 --- a/tests/scripts/test_extend_template.py +++ b/tests/scripts/test_extend_template.py @@ -1631,7 +1631,7 @@ def test_index_multiple_slice_3(): def test_tensor_item_with_ellipsis_at_beginning(): - input = IOKey("input", shape=(3, 4, 5)) + input = IOKey("input", shape=(3, 4, 5), differantiable=True) model = Model() buff_model = Buffer() buff_model.set_types(input=Tensor) @@ -1649,7 +1649,7 @@ def test_tensor_item_with_ellipsis_at_beginning(): def test_tensor_item_with_ellipsis_in_middle(): - input = IOKey("input", shape=(2, 3, 4, 5, 6)) + input = IOKey("input", shape=(2, 3, 4, 5, 6), differantiable=True) model = Model() buff_model = Buffer() buff_model.set_types(input=Tensor) @@ -1670,7 +1670,7 @@ def test_tranpose_1(): backend = JaxBackend() model = Model() - input = IOKey("input") + input = IOKey("input", differantiable=True) result = input.transpose() model += Buffer()(input=result, output="output") @@ -1684,7 +1684,7 @@ def test_tranpose_2(): backend = JaxBackend() model = Model() - input = IOKey("input") + input = IOKey("input", differantiable=True) result = input.transpose() model += Buffer()(input=result, output="output") @@ -1703,7 +1703,7 @@ def test_tranpose_3(): input_arr = backend.ones(4, 3, 2) axis = random.shuffle(list(range(input_arr.ndim))) - input = IOKey("input") + input = IOKey("input", differantiable=True) result = input.transpose(axis) model += Buffer()(input=result, output="output") @@ -1722,7 +1722,7 @@ def test_tranpose_4(): input_arr = jnp.ones(8) axis = random.shuffle(list(range(input_arr.ndim))) - input = IOKey("input") + input = IOKey("input", differantiable=True) result = input.transpose(axis) model += Buffer()(input=result, output="output") @@ -1740,7 +1740,7 @@ def test_split_direct(): input_arr = jnp.ones((8, 16)) - input = IOKey("input") + input = IOKey("input", differantiable=True) result = input.split(2, axis=1) model += Buffer()(input=result, output="output") diff --git a/tests/scripts/test_flat_graph.py b/tests/scripts/test_flat_graph.py index 0b9b3d63..ba952dd1 100644 --- a/tests/scripts/test_flat_graph.py +++ b/tests/scripts/test_flat_graph.py @@ -90,14 +90,14 @@ def test_flatgraph_4(): model = Model() model += model_1() model += model_2( - relu_2="", + relu_2="input", output_2=model_1.relu_2, # type: ignore relu_1=model_1.output_2, # type: ignore output_1=ml.IOKey(name="output"), ) pm = ml.compile(model=model, backend=backend) - assert pm.input_keys == {"relu_2"} + assert pm.input_keys == {"input"} assert len(pm.flat_graph.all_source_keys) == 3 assert len(pm.flat_graph.all_target_keys) == 3 diff --git a/tests/scripts/test_flatmodel.py b/tests/scripts/test_flatmodel.py index 1d981cfe..da3536d9 100644 --- a/tests/scripts/test_flatmodel.py +++ b/tests/scripts/test_flatmodel.py @@ -347,8 +347,8 @@ def test_integration_with_all_defined(): pm_long = ml.compile(model, backend, use_short_namings=False) inputs = {"a": backend.array([1, 2, 3]), "b": backend.array([4, 5, 6])} - res_short = pm_short.evaluate(inputs) - res_long = pm_long.evaluate(inputs) + res_short = pm_short.evaluate(data=inputs) + res_long = pm_long.evaluate(data=inputs) expected_res = {"c": backend.array([5, 7, 9], dtype=ml.int64)} @@ -364,14 +364,14 @@ def test_integration_with_some_undefined(): add.set_types(left=Tensor, right=Tensor) model += add(right="b", output="c") - pm_short = ml.compile(model, backend) - pm_long = ml.compile(model, backend, use_short_namings=False) + pm_short = ml.compile(model, backend, safe_names=False) + pm_long = ml.compile(model, backend, use_short_namings=False, safe_names=False) inputs_short = {"left": backend.array((1, 2, 3)), "b": backend.array([4, 5, 6])} inputs_long = {"add_left": backend.array((1, 2, 3)), "b": backend.array([4, 5, 6])} - res_short = pm_short.evaluate(inputs_short) - res_long = pm_long.evaluate(inputs_long) + res_short = pm_short.evaluate(data=inputs_short) + res_long = pm_long.evaluate(data=inputs_long) expected_res = {"c": backend.array([5, 7, 9], dtype=ml.int64)} @@ -383,7 +383,11 @@ def test_integration_multi_level_name_with_lowest_definition(): model2 = Model("adder") add = Add() add.set_types(left=Tensor, right=Tensor) - model2 += add(left="a", right="b", output="c") + model2 += add( + left=IOKey("a", differantiable=True), + right=IOKey("b", differantiable=True), + output=IOKey("c", differantiable=True), + ) model1 = Model(name="model") model1 += model2 @@ -424,8 +428,8 @@ def test_integration_collision_from_different_levels(): backend = JaxBackend(dtype=ml.float64) - pm_short = ml.compile(model, backend) - pm_long = ml.compile(model, backend, use_short_namings=False) + pm_short = ml.compile(model, backend, safe_names=False) + pm_long = ml.compile(model, backend, use_short_namings=False, safe_names=False) input_short = {"d": backend.array([1, 2, 3]), "e": backend.array([4, 5, 6])} input_long = { @@ -433,8 +437,8 @@ def test_integration_collision_from_different_levels(): "middle_e": backend.array([4, 5, 6]), } - res_short = pm_short.evaluate(input_short) - res_long = pm_long.evaluate(input_long) + res_short = pm_short.evaluate(data=input_short) + res_long = pm_long.evaluate(data=input_long) expected_res = {"_e": backend.array([5, 7, 9], dtype=ml.int64)} diff --git a/tests/scripts/test_functions.py b/tests/scripts/test_functions.py index 2cfa84e3..c6eb9999 100644 --- a/tests/scripts/test_functions.py +++ b/tests/scripts/test_functions.py @@ -373,7 +373,7 @@ def test_code_generator_2(file_path: str): eval_func = import_module("tmp." + file_name).evaluate def evaluate(params, data, cache): - input = params["input"] + input = data["input"] return {"output1": input, "output2": input} compare_callables(evaluate, eval_func) @@ -431,6 +431,9 @@ def my_adder(left, right, cache: None): NumpyBackend.register_primitive(my_adder, add_grad) model += MyAdder()(left="left", right="right", output=IOKey(name="output")) + model.set_differentiability(left=True) + model.set_differentiability(right=True) + context = TrainModel(model) context.add_loss( BinaryCrossEntropy(), reduce_steps=[Mean()], input="output", target="target" @@ -525,9 +528,14 @@ def my_adder(left, right): JaxBackend.register_primitive(my_adder) model += MyAdder()(left="left", right="right", output=IOKey(name="output")) + model.set_differentiability(left=True) + model.set_differentiability(right=True) + context = TrainModel(model) add = Add() add.set_types(right=Tensor) + add.set_differentiability(right=True) + context.add_loss( BinaryCrossEntropy(), reduce_steps=[add], input="output", target="target" ) @@ -681,6 +689,9 @@ def test_code_generator_8(file_path: str): model = Model() add = Add() add.set_types(left=Tensor, right=Tensor) + add.set_differentiability(left=True) + add.set_differentiability(right=True) + model += add(left="left", right="right") model += Multiply()(left=add.output, right="right2", output="output") diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index d3c6ace0..c825f677 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -303,7 +303,7 @@ def test_9(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) - res = pm.evaluate(params={"input": backend.ones(5, 5)}) + res = pm.evaluate(data={"input": backend.ones(5, 5)}) out1 = res["output"] assert isinstance(out1, torch.Tensor) np.testing.assert_array_equal( @@ -320,7 +320,7 @@ def test_10(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) - res = pm.evaluate(params={"input": backend.ones(5, 5)}) + res = pm.evaluate(data={"input": backend.ones(5, 5)}) out = res["output"] assert isinstance(out, torch.Tensor) @@ -339,7 +339,7 @@ def test_11(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) - res = pm.evaluate(params={"input": backend.ones(5, 5)}) + res = pm.evaluate(data={"input": backend.ones(5, 5)}) out = res["output"] assert isinstance(out, torch.Tensor) np.testing.assert_array_equal( @@ -355,7 +355,7 @@ def test_12(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) - res = pm.evaluate(params={"input": backend.ones(5, 5)}) + res = pm.evaluate(data={"input": backend.ones(5, 5)}) out = res["output"] assert isinstance(out, torch.Tensor) @@ -375,7 +375,7 @@ def test_13(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) - res = pm.evaluate(params={"input": backend.ones(5, 5)}) + res = pm.evaluate(data={"input": backend.ones(5, 5)}) out1 = res["output1"] assert isinstance(out1, torch.Tensor) out2 = res["output2"] @@ -700,7 +700,7 @@ def test_iokey_tensor_input_all_args(): backend = TorchBackend() # collect all possible values - possible_names = ["left", None] + possible_names = ["left"] possible_values = [Tensor([[2.0]]), TBD] possible_shapes = [[1, 1], None] possible_expose = [True, False] @@ -755,11 +755,11 @@ def test_iokey_tensor_input_all_args(): # successfully. pm = mithril.compile(model=model, backend=backend) if value is TBD: - params = {"left": backend.array([[2.0]]), "right": backend.array([[3.0]])} + data = {"left": backend.array([[2.0]]), "right": backend.array([[3.0]])} else: - params = {"right": backend.array([[3.0]])} + data = {"right": backend.array([[3.0]])} - outputs = pm.evaluate(params=params) + outputs = pm.evaluate(data=data) assert_results_equal(outputs, ref_outputs) @@ -855,12 +855,12 @@ def test_iokey_scalar_output_all_args(): # successfully. pm = mithril.compile(model=model, backend=backend, inference=True) - params = {"input": backend.ones(2, 3, 4)} + data = {"input": backend.ones(2, 3, 4)} if name is not None: # and expose: ref_outputs = {"output1": (2, 3, 4)} else: ref_outputs = {"output": (2, 3, 4)} - outputs = pm.evaluate(params=params) + outputs = pm.evaluate(data=data) assert_results_equal(outputs, ref_outputs) @@ -952,17 +952,16 @@ def test_iokey_scalar_input_all_args(): # if code reaches this far. It is expected model to be compiled and evaluated # successfully. pm = mithril.compile(model=model, backend=backend, safe_names=False) - params = { + data: dict = { "input": backend.ones(2, 2), } - data = {} if value is TBD: - data = {"axis": 0} + data |= {"axis": 0} if expose and name is not None: - data = {"axis1": 0} + data |= {"axis1": 0} ref_outputs = {"output": backend.ones(2)} - outputs = pm.evaluate(params=params, data=data) + outputs = pm.evaluate(data=data) assert_results_equal(outputs, ref_outputs) @@ -1048,12 +1047,12 @@ def test_iokey_tensor_output_all_args(): # if code reaches this far. It is expected model to be compiled and # evaluated successfully. pm = mithril.compile(model=model, backend=backend) - params = {"left": backend.array([[2.0]]), "right": backend.array([[3.0]])} + data = {"left": backend.array([[2.0]]), "right": backend.array([[3.0]])} if name is not None: # and expose: ref_outputs = {"output1": backend.array([[5.0]])} else: ref_outputs = {"output": backend.array([[5.0]])} - outputs = pm.evaluate(params=params) + outputs = pm.evaluate(data=data) assert_results_equal(outputs, ref_outputs) @@ -1077,8 +1076,12 @@ def test_compare_models_1(): model2 += add(left="input1", right="input2") model2 += multiply(left=add.output, right="input3", output=IOKey("output")) - - compare_evaluate(model1=model1, model2=model2, backend=backend, data={}) + data = { + "input1": backend.ones(5, 5), + "input2": backend.ones(5, 5), + "input3": backend.ones(5, 5), + } + compare_evaluate(model1=model1, model2=model2, backend=backend, data=data) def test_compare_models_2(): @@ -1087,7 +1090,8 @@ def test_compare_models_2(): model1 = Model() linear1 = Linear(dimension=3) - linear1.input.set_differentiable(True) + linear1.set_differentiability(input=True) + linear2 = Linear(dimension=3) model1 += linear1(input="input", output="sub_out") @@ -1096,7 +1100,7 @@ def test_compare_models_2(): model2 = Model() linear1 = Linear(dimension=3) - linear1.input.set_differentiable(True) + linear1.set_differentiability(input=True) linear2 = Linear(dimension=3) model2 += linear1(input="input") @@ -1130,7 +1134,8 @@ def test_compare_models_3(): model2 += sig_model3 model2.set_shapes({"input": [2, 2]}) - compare_evaluate(model1=model1, model2=model2, backend=backend, data={}) + data = {"input": backend.ones(2, 2)} + compare_evaluate(model1=model1, model2=model2, backend=backend, data=data) def test_compare_models_4(): @@ -1157,7 +1162,8 @@ def test_compare_models_4(): model2 += sig_model3 model2.set_shapes({"input": [2, 2]}) - compare_evaluate(model1=model1, model2=model2, backend=backend, data={}) + data = {"input": backend.ones(2, 2)} + compare_evaluate(model1=model1, model2=model2, backend=backend, data=data) def test_compare_models_5(): @@ -1179,7 +1185,8 @@ def test_compare_models_5(): model2 += sigmoid(input="input", output=conn) model2.set_shapes({"input": [2, 2]}) - compare_evaluate(model1=model1, model2=model2, backend=backend, data={}) + data = {"input": backend.ones(2, 2)} + compare_evaluate(model1=model1, model2=model2, backend=backend, data=data) def test_iokey_shape_error_1(): @@ -1221,7 +1228,7 @@ def test_iokey_template_1(): pm = mithril.compile(model=model, backend=backend, jit=False) out = pm.evaluate( - params={"left": backend.array([2.0]), "right": backend.array([3.0])} + data={"left": backend.array([2.0]), "right": backend.array([3.0])} ) expected_result = np.array([8.0]) @@ -1243,7 +1250,7 @@ def test_iokey_template_2(): pm = mithril.compile(model=model, backend=backend, jit=False) res = pm.evaluate( - params={"left": backend.array([2.0]), "right": backend.array([3.0])} + data={"left": backend.array([2.0]), "right": backend.array([3.0])} ) expected_result = np.array([5.0]) @@ -1263,7 +1270,7 @@ def test_iokey_template_3(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) - out = pm.evaluate(params={"left": backend.array([2.0])}) + out = pm.evaluate(data={"left": backend.array([2.0])}) expected_result = np.array([5.0]) assert pm.input_keys == {"left", "input"} @@ -1282,7 +1289,7 @@ def test_iokey_template_4(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) - out = pm.evaluate(params={"left": backend.ones((9, 8, 7))}) + out = pm.evaluate(data={"left": backend.ones((9, 8, 7))}) expected_result = 9 assert pm.input_keys == {"left", "index"} @@ -1322,7 +1329,7 @@ def test_iokey_template_6(): pm._input_keys = {"input"} pm._output_keys = {"output"} - res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) + res = pm.evaluate(data={"input": backend.ones((3, 4, 5))}) out = res["output"] assert isinstance(out, torch.Tensor) np.testing.assert_almost_equal(out, np.ones((4, 5))) @@ -1339,7 +1346,7 @@ def test_iokey_template_7(): pm._input_keys = {"input"} pm._output_keys = {"output"} - res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) + res = pm.evaluate(data={"input": backend.ones((3, 4, 5))}) assert res["output"] == (4, 5) @@ -1357,7 +1364,7 @@ def test_iokey_template_8(): pm._input_keys = {"input"} pm._output_keys = {"output"} - res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) + res = pm.evaluate(data={"input": backend.ones((3, 4, 5))}) assert res["output2"] == (4, 5) @@ -1375,7 +1382,7 @@ def test_iokey_template_9(): pm._input_keys = {"input"} pm._output_keys = {"output"} - res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) + res = pm.evaluate(data={"input": backend.ones((3, 4, 5))}) assert res["output2"] == (4, 5) @@ -1393,7 +1400,7 @@ def test_iokey_template_10(): pm._input_keys = {"input"} pm._output_keys = {"output"} - res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) + res = pm.evaluate(data={"input": backend.ones((3, 4, 5))}) np.testing.assert_equal(res["output1"], np.ones((3, 4, 5))) np.testing.assert_equal(res["output2"], np.ones((3, 4, 5))) @@ -1412,7 +1419,7 @@ def test_iokey_template_11(): pm._input_keys = {"input"} pm._output_keys = {"output"} - res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) + res = pm.evaluate(data={"input": backend.ones((3, 4, 5))}) np.testing.assert_equal(res["output1"], np.ones((3, 4, 5))) np.testing.assert_equal(res["output2"], np.ones((3, 4, 5))) @@ -1433,7 +1440,7 @@ def test_iokey_template_12(): pm._input_keys = {"input"} pm._output_keys = {"output"} - res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) + res = pm.evaluate(data={"input": backend.ones((3, 4, 5))}) np.testing.assert_equal(res["output"], np.ones((3, 4, 5))) @@ -1450,5 +1457,5 @@ def test_iokey_template_13(): pm._input_keys = {"input"} pm._output_keys = {"output"} - res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) + res = pm.evaluate(data={"input": backend.ones((3, 4, 5))}) assert res["output"] == (4, 5) diff --git a/tests/scripts/test_jittable.py b/tests/scripts/test_jittable.py index 8b7869e1..f1f4eb6d 100644 --- a/tests/scripts/test_jittable.py +++ b/tests/scripts/test_jittable.py @@ -110,7 +110,6 @@ def __init__(self, dimension: int | None = None) -> None: "b": [dimension], } self.set_shapes(shapes) - ... class MyModel2(Model): @@ -243,7 +242,9 @@ def __call__( # type: ignore[override] return ExtendInfo(self, kwargs) model = MyModel(dimension=1) - model += Adder()(left="output", right="r1", output=IOKey(name="o1")) + model += Adder()( + left="output", right=IOKey("r1", differantiable=True), output=IOKey(name="o1") + ) compiled_model = compile( model=model, backend=JaxBackend(), constant_keys=static_inputs, jit=True ) @@ -304,22 +305,29 @@ def test_physical_model_jit_1(): """ model = Model(enforce_jit=False) add1 = Add() - add1.set_types(left=Tensor, right=Tensor) add2 = Add() - add2.set_types(left=Tensor, right=Tensor) - model += add1(left="l1", right="l2", output=IOKey(name="out1")) - model += add2(left="l3", right="l4") + model += add1( + left=IOKey("l1", differantiable=True), + right=IOKey("l2", differantiable=True), + output=IOKey(name="out1"), + ) + model += add2( + left=IOKey("l3", differantiable=True), right=IOKey("l4", differantiable=True) + ) model.enforce_jit = False - input = IOKey(name="input", connections={add1.left, add2.left}, expose=True) + input = IOKey( + name="input", + connections={add1.left, add2.left}, + expose=True, + differantiable=True, + ) model += Item()(input=input) backend = JaxBackend() compiled_model = compile(model=model, backend=backend, jit=False) inputs = compiled_model.randomize_params() output_gradients = {"out1": backend.ones_like(inputs["input"])} - outputs, grads = compiled_model.evaluate_all( - inputs, output_gradients=output_gradients - ) + compiled_model.evaluate_all(inputs, output_gradients=output_gradients) def test_physical_model_jit_2(): @@ -376,14 +384,19 @@ def test_jit_2(): backend = JaxBackend() model = Model(enforce_jit=False) model += (add_model := Add())( - left=IOKey("left", type=Tensor), right=IOKey("right", type=Tensor) + left=IOKey("left", differantiable=True), + right=IOKey("right", differantiable=True), ) in1 = add_model.output out1 = in1.shape out2 = out1.tensor().sum() mean_model = Mean(axis=TBD) model += (to_list := Item())(input=out2) - model += mean_model(input="input", axis=to_list.output, output=IOKey(name="output")) + model += mean_model( + input=IOKey("input", differantiable=True), + axis=to_list.output, + output=IOKey(name="output"), + ) pm = compile(model=model, backend=backend, jit=False) params = { "left": backend.randn(1, 1), @@ -392,7 +405,6 @@ def test_jit_2(): } pm.evaluate(params=params) # TODO: Make required assertions!!! - ... def test_jit_3(): @@ -401,10 +413,9 @@ def test_jit_3(): model += Mean(axis=TBD)(input="input", output=IOKey(name="output"), axis="axis") pm = compile(model=model, backend=backend, jit=False) - inputs = {"input": backend.randn(1, 2, 3, 2, 3, 2, 3, 2)} - data = {"axis": 3} + inputs = {"input": backend.randn(1, 2, 3, 2, 3, 2, 3, 2), "axis": 3} - pm.evaluate(params=inputs, data=data) + pm.evaluate(data=inputs) # type: ignore def test_jit_4(): @@ -414,9 +425,8 @@ def test_jit_4(): pm = compile(model=model, backend=backend, jit=True, constant_keys={"axis": 3}) inputs = {"input": backend.randn(1, 2, 3, 2, 3, 2, 3, 2)} - data = {"axis": 3} - pm.evaluate(params=inputs, data=data) + pm.evaluate(data=inputs) def test_jit_5(): diff --git a/tests/scripts/test_model_to_dict_rtt.py b/tests/scripts/test_model_to_dict_rtt.py index 6247d582..2ccc8f2e 100644 --- a/tests/scripts/test_model_to_dict_rtt.py +++ b/tests/scripts/test_model_to_dict_rtt.py @@ -113,7 +113,7 @@ def test_linear_set_diff(): model = Model() linear = Linear(dimension=42) model += linear(input="input", weight="weight", output=IOKey(name="output")) - linear.weight.set_differentiable(False) + linear.set_differentiability(weight=False) model_dict_created = dict_conversions.model_to_dict(model) model_recreated = dict_conversions.dict_to_model(model_dict_created) @@ -194,7 +194,7 @@ def test_constant_key(): def test_constant_key_2(): model = Model() model += (add := Add())( - left=IOKey("input", type=Tensor), + left=IOKey("input", type=Tensor, differantiable=True), right=IOKey(value=Tensor(3)), output=IOKey(name="output"), ) @@ -834,7 +834,7 @@ def test_set_values_constant_1(): ) model += Linear(1)( weight="weight1", - bias=IOKey(value=Tensor([123]), name="bias1"), + bias=IOKey(value=Tensor([123.0]), name="bias1"), input="input2", output=IOKey(name="output2"), ) @@ -872,7 +872,7 @@ def test_set_values_constant_2(): input="input2", output=IOKey(name="output2"), ) - model.set_values({"bias1": Tensor([123])}) + model.set_values({"bias1": Tensor([123.0])}) model_dict_created = dict_conversions.model_to_dict(model) model_recreated = dict_conversions.dict_to_model(model_dict_created) @@ -922,7 +922,7 @@ def test_set_values_ellipsis_2(): ) lin2 = Linear(1) model.extend(lin2, weight="weight1", bias="bias1", input="input2") - lin2.bias.set_differentiable(False) + lin2.set_differentiability(bias=False) model_dict_created = dict_conversions.model_to_dict(model) model_recreated = dict_conversions.dict_to_model(model_dict_created) diff --git a/tests/scripts/test_primitive_calls.py b/tests/scripts/test_primitive_calls.py index 7c60c5b7..ac1145f1 100644 --- a/tests/scripts/test_primitive_calls.py +++ b/tests/scripts/test_primitive_calls.py @@ -58,9 +58,8 @@ def test_error_not_robust_power_call_threshold_float(): def test_compile_robust_power_call_with_default_threshold(): backend = ml.TorchBackend() pow = Power(robust=True) - # pow.set_types({"base": Tensor, "exponent": Tensor}) pm = ml.compile(pow, backend) - pm.evaluate(params={"base": backend.ones(3, 3), "exponent": backend.ones(3, 3)}) + pm.evaluate(data={"base": backend.ones(3, 3), "exponent": backend.ones(3, 3)}) @pytest.mark.skip( diff --git a/tests/scripts/test_randomized_models_all_backends.py b/tests/scripts/test_randomized_models_all_backends.py index 9dc73bea..12be1278 100644 --- a/tests/scripts/test_randomized_models_all_backends.py +++ b/tests/scripts/test_randomized_models_all_backends.py @@ -183,6 +183,18 @@ def test_randomized(case: str) -> None: for key, value in static_inputs[init_key].items() } + trainble_keys: set[str] = set() + for input_key in model.input_keys: + if input_key in ["threshold", "cutoff"]: + continue + + if ( + input_key not in static_input_info + and not input_key.startswith("$") + and not model.conns.all[input_key].metadata.is_scalar + ): + trainble_keys.add(input_key) + shapes: dict[str, list[int]] = {} for key, value in input_info.items(): shape = value["shapes"] @@ -200,6 +212,7 @@ def test_randomized(case: str) -> None: model=model, constant_keys=static_inputs[init_key], backend=init_backend, # type: ignore + trainable_keys=trainble_keys, shapes=shapes, jit=True, inference=inference, @@ -246,6 +259,7 @@ def test_randomized(case: str) -> None: model=model, constant_keys=static_inputs[backend.backend_type], backend=backend, # type: ignore[reportArgumentType] + trainable_keys=trainble_keys, shapes=shapes, jit=True, inference=inference, diff --git a/tests/scripts/test_recurrent_models.py b/tests/scripts/test_recurrent_models.py index b7dd7486..9ee864d7 100644 --- a/tests/scripts/test_recurrent_models.py +++ b/tests/scripts/test_recurrent_models.py @@ -218,21 +218,24 @@ def __init__( self |= slice_2(start="", stop=indexer.output) self += tensor_item_2(input="prev_hidden", index=slice_2.output) - self += mult_model_1(left="input", right="w_ih") - self += mult_model_2(left=tensor_item_2.output, right="w_hh") + self += mult_model_1(left="input", right=IOKey("w_ih", differantiable=True)) + self += mult_model_2( + left=tensor_item_2.output, right=IOKey("w_hh", differantiable=True) + ) self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2( - left=sum_model_1.output, right=IOKey("bias_hh", type=Tensor) + left=sum_model_1.output, + right=IOKey("bias_hh", type=Tensor, differantiable=True), ) self += sum_model_3( left=sum_model_2.output, - right=IOKey("bias_ih", type=Tensor), + right=IOKey("bias_ih", type=Tensor, differantiable=True), ) self += tanh(input=sum_model_3.output, output=IOKey("hidden")) - self += mult_model_3(left="hidden", right="w_ho") + self += mult_model_3(left="hidden", right=IOKey("w_ho", differantiable=True)) self += sum_model_4( left=mult_model_3.output, - right=IOKey("bias_o", type=Tensor), + right=IOKey("bias_o", type=Tensor, differantiable=True), output=IOKey("output"), ) @@ -345,15 +348,18 @@ def __init__( ) self |= slice_2(start="", stop=indexer.output) self += tensor_item_2(input="prev_hidden", index=slice_2.output) - self += mult_model_1(left="input", right="w_ih") - self += mult_model_2(left=tensor_item_2.output, right="w_hh") + self += mult_model_1(left="input", right=IOKey("w_ih", differantiable=True)) + self += mult_model_2( + left=tensor_item_2.output, right=IOKey("w_hh", differantiable=True) + ) self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2( - left=sum_model_1.output, right=IOKey("bias_hh", type=Tensor) + left=sum_model_1.output, + right=IOKey("bias_hh", type=Tensor, differantiable=True), ) self += sum_model_3( left=sum_model_2.output, - right=IOKey("bias_ih", type=Tensor), + right=IOKey("bias_ih", type=Tensor, differantiable=True), ) self += tanh(input=sum_model_3.output, output=IOKey("hidden")) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 776e9667..a3007a86 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -371,9 +371,9 @@ def test_shape(): sigmoid2.set_shapes({"input": [5, 6, 8, 9, 10]}) model3 += sigmoid2(input="input2", output=IOKey(name="output2")) - model += model1(input2="", output2=IOKey(name="output")) + model += model1(input2="in2", output2=IOKey(name="output")) model += model2(input1=model1.output1, input2=model1.output2) # type: ignore - model |= model3(input2="", output1=model1.input1, output2=model1.input2) # type: ignore + model |= model3(input2="in3", output1=model1.input1, output2=model1.input2) # type: ignore comp_model = mithril.compile(model, backend=NumpyBackend(dtype=mithril.float64)) assert comp_model.shapes["output"] == [5, 6, 8, 9, 10] @@ -584,7 +584,7 @@ def test_pickle_empty_backend(): ) model = Linear(dimension=5) - model.input.set_differentiable(True) + model.set_differentiability(input=True) model.set_shapes({"input": [5, 5]}) ctx = TrainModel(model) ctx.add_loss(Buffer(), input=model.cout) @@ -792,7 +792,7 @@ def test_canonical_output_compile(): def test_static_key_names_consistency(): model = Model() - model += Add()(left=Tensor(3), right=IOKey(type=Tensor)) + model += Add()(left=Tensor(3), right=IOKey(name="right", type=Tensor)) pm = mithril.compile(model, TorchBackend()) assert {"left", "right"} == pm.input_keys @@ -848,20 +848,18 @@ def test_check_static_1(): model = Model() lin1 = Linear(dimension=1) model += lin1( - input=Tensor([[2, 3], [1, 4]]), - weight=Tensor([[4, 5]]), - bias=Tensor([3]), + input=Tensor([[2.0, 3.0], [1.0, 4.0]]), + weight=Tensor([[4.0, 5.0]]), + bias=Tensor([3.0]), output="output", ) comp_model = compile( model=model, backend=NumpyBackend(), - jit=False, inference=True, ) - # inputs = {"w": np.array([[4.0], [5.0]]), - # "b": np.array([3.0])} + outputs = comp_model.evaluate() ref_out = outputs["output"] assert isinstance(ref_out, np.ndarray) @@ -887,8 +885,8 @@ def test_check_static_3(): model = Model() lin1 = Linear(dimension=1) model += lin1( - input=Tensor([[2, 3], [1, 4]]), - weight=Tensor([[4, 5]]), + input=Tensor([[2.0, 3.0], [1.0, 4.0]]), + weight=Tensor([[4.0, 5.0]]), bias="bias", output="output", ) @@ -989,22 +987,22 @@ def test_cyclic_extension(): output1=model1.cin, output2=IOKey("output"), ) - comp_model = mithril.compile(model=model1, backend=NumpyBackend()) + comp_model = mithril.compile(model=model1, backend=NumpyBackend(), jit=False) inputs = {"input": np.array([[2.0]])} - outputs = comp_model.evaluate(inputs) + outputs = comp_model.evaluate(data=inputs) assert_results_equal(outputs, {"output": np.array([[2.0]])}) def test_canonic_example(): model = Model() - model += LeakyRelu() + model += LeakyRelu()("input") model += LeakyRelu() comp_model = compile(model=model, backend=NumpyBackend()) assert set(comp_model.input_keys) == {"slope_0", "slope_1", "input"} assert set(comp_model.output_keys) == {"output"} inputs = {"input": np.array([[2.0, -1.0]])} assert_results_equal( - comp_model.evaluate(inputs), {"output": np.array([[2.0, -0.0001]])} + comp_model.evaluate(data=inputs), {"output": np.array([[2.0, -0.0001]])} ) @@ -1152,7 +1150,7 @@ def test_train_context_example(): model = Model() model += Linear(1)(input="input", output=IOKey(name="output")) model += Linear(1)(input=model.cout, output=IOKey(name="output2")) - model.input.set_differentiable(True) # type: ignore + model.set_differentiability(input=True) context = TrainModel(model) context.add_loss(Buffer(), [Sum()], input="output2") @@ -1410,17 +1408,17 @@ def test_flatten_dag0(): model = Model() l1 = Linear(10) l5 = Linear(1) - l1.input.set_differentiable(True) - l5.input.set_differentiable(True) + l1.set_differentiability(input=True) + l5.set_differentiability(input=True) model += l1(weight="weight_2") model += (lin1 := Linear(10))(input="") model += (lin2 := Linear(10))(input="") model += (lin3 := Linear(10))(input="") model += l5(input="", output=IOKey(name="output1")) - lin1.input.set_differentiable(True) - lin2.input.set_differentiable(True) - lin3.input.set_differentiable(True) + lin1.set_differentiability(input=True) + lin2.set_differentiability(input=True) + lin3.set_differentiability(input=True) l5.set_shapes({"input": [1, 1]}) model.set_cout(l1.output) @@ -1447,7 +1445,7 @@ def test_geo_mean_1(): backend = TorchBackend() model = Model() model += (lin := Linear(1))(weight="weight2") - lin.input.set_differentiable(True) + lin.set_differentiability(input=True) context = TrainModel(model) context.add_loss(Buffer(), input=model.cout) @@ -1669,7 +1667,8 @@ def test_geomean_evaluate(): }, ) model1.set_shapes({"input": [10, 10, 10]}) - lin1.input.set_differentiable(True) + lin1.set_differentiability(input=True) + ctx1 = TrainModel(model1) ctx1.add_loss( Buffer(), @@ -1702,7 +1701,8 @@ def test_geomean_evaluate(): "output": IOKey("output2"), }, ) - lin2.input.set_differentiable(True) + lin2.set_differentiability(input=True) + ctx2 = TrainModel(model2) ctx2.add_loss( Buffer(), @@ -1801,8 +1801,8 @@ def test_regularization_1(): # Test with single regularization and single reduce (mean) operation model = Model() model += Multiply()( - left=IOKey("left", type=Tensor), - right=IOKey("w", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), + right=IOKey("w", type=Tensor, differantiable=True), output="output", ) @@ -1825,8 +1825,8 @@ def test_regularization_1_sanity_test(): model = Model() model.extend( Multiply(), - left=IOKey("left", type=Tensor), - right=IOKey("w", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), + right=IOKey("w", type=Tensor, differantiable=True), output="output", ) @@ -1850,8 +1850,8 @@ def test_regularization_2(): # Test with single regularization and single reduce (sum) operation model = Model() model += Multiply()( - left=IOKey("left", type=Tensor), - right=IOKey("w", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), + right=IOKey("w", type=Tensor, differantiable=True), output="output", ) @@ -1875,8 +1875,8 @@ def test_regularization_3(): # operations model = Model() model += Multiply()( - left=IOKey("left", type=Tensor), - right=IOKey("w", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), + right=IOKey("w", type=Tensor, differantiable=True), output="output", ) @@ -1904,8 +1904,8 @@ def test_regularization_4(): # Test with single regularization and multiple model with multiple reduce operations model = Model() model += Multiply()( - left=IOKey("left", type=Tensor), - right=IOKey("w", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), + right=IOKey("w", type=Tensor, differantiable=True), output=IOKey(name="output"), ) model += Multiply()(left="left", right="w", output=IOKey(name="output2")) @@ -1941,8 +1941,8 @@ def test_regularization_5(): # Test with single regularization and multiple model with multiple reduce operations model = Model() model += Multiply()( - left=IOKey("left", type=Tensor), - right=IOKey("w", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), + right=IOKey("w", type=Tensor, differantiable=True), output=IOKey(name="output"), ) model += Multiply()( @@ -2010,7 +2010,7 @@ def test_static_anlaysis_1(): ) model += Add()( left=add1.output, - right=IOKey(type=Tensor), + right=IOKey(name="right2", type=Tensor), output=IOKey(name="output1"), ) @@ -2033,7 +2033,7 @@ def test_static_anlaysis_2(): model += sum1(input=add1.output) model += Add()( left=sum1.output, - right=IOKey(type=Tensor), + right=IOKey(name="right2", type=Tensor), output=IOKey(name="output1"), ) @@ -2048,7 +2048,7 @@ def test_static_anlaysis_2(): ) -def test_static_anlaysis_4(): +def test_static_anlaysis_3(): model = Model() model += (add1 := Add()) add1.set_types(left=Tensor, right=Tensor) @@ -2064,7 +2064,7 @@ def test_static_anlaysis_4(): model.set_cin(add1.left) model.set_cout(mul1.output) - comp_model = mithril.compile(model=model, backend=NumpyBackend()) + comp_model = mithril.compile(model=model, backend=NumpyBackend(), safe_names=False) models = {add1, add2, sum1, sub1, mul1, mat1} _models = {model.submodel for model in models} @@ -2676,8 +2676,8 @@ def test_prune_duplicate_grad(): div2 = Divide() mm2 = MatrixMultiply() mm3 = MatrixMultiply() - model += sig1(input="input1") - model += sig2(input="input2") + model += sig1(input=IOKey("input1", differantiable=True)) + model += sig2(input=IOKey("input2", differantiable=True)) model += log1(input=sig1.output) model += log2(input=sig1.output) model += mm1(left=log1.output, right=log2.output) @@ -3271,7 +3271,7 @@ def forward(data): def test_add_loss_unknown_key(): model = Model() l1 = Linear() - model += l1(input="input", weight="w0") + model += l1(input=IOKey("input", differantiable=True), weight="w0") model += Linear()(input=l1.output, weight="w1", output=IOKey(name="output")) context = TrainModel(model) @@ -3349,7 +3349,7 @@ def test_add_regularization_unknown_key(): def test_add_regularization(): model = Model() l1 = Linear(1) - model += l1(input="input", weight=Tensor([[2]])) + model += l1(input="input", weight=Tensor([[2.0]])) model += Linear()(input=l1.output, weight="w1", output=IOKey(name="output")) context = TrainModel(model) @@ -3612,7 +3612,7 @@ def test_composite_6_extend_from_inputs_connect(): backend = TorchBackend() cm = mithril.compile(model, backend=backend) - cm.evaluate(params={"my_input": backend.array([[[[1.0, 2.0, 3.0]]]])}) + cm.evaluate(data={"my_input": backend.array([[[[1.0, 2.0, 3.0]]]])}) def test_composite_4_extend_from_inputs_connect(): @@ -3629,7 +3629,7 @@ def test_composite_4_extend_from_inputs_connect(): backend = TorchBackend() cm = mithril.compile(model, backend=backend) - cm.evaluate(params={"input1": backend.array([[[[1.0, 2.0, 3.0]]]])}) + cm.evaluate(data={"input1": backend.array([[[[1.0, 2.0, 3.0]]]])}) assert ( relu1.input.data.metadata == relu2.input.data.metadata @@ -3677,7 +3677,7 @@ def test_mlp_last_dimension_prop_2(): ctx.add_loss(AbsoluteError(), input="output", target=Tensor([2.0])) comp_model = mithril.compile(model=ctx, backend=NumpyBackend()) inputs = {"in1": np.array([3.0]), "in2": np.array([2.0])} - outputs = comp_model.evaluate(inputs) + outputs = comp_model.evaluate(data=inputs) output_final_cost = outputs["final_cost"] out = outputs["output"] assert isinstance(output_final_cost, np.ndarray) @@ -3920,8 +3920,8 @@ def test_add_loss_coef(): backend = TorchBackend(dtype=mithril.float64) model = Model() model += Multiply()( - left=IOKey("left", type=Tensor), - right=IOKey("w", type=Tensor), + left=IOKey("left", type=Tensor, differantiable=True), + right=IOKey("w", type=Tensor, differantiable=True), output=IOKey(name="output"), ) @@ -3981,7 +3981,7 @@ def test_cycle_handling_1(): model_2 += Tanh()(input="input1", output=IOKey(name="output1")) model_2 += Sine()(input="input2", output=IOKey(name="output2")) model += model_2( - input2="input", + input2=IOKey("input", differantiable=True), output2=IOKey("output2"), input1="input1", output1=IOKey(name="output"), @@ -4096,7 +4096,9 @@ def test_cycle_handling_2(): model += (gelu5 := Gelu())() - model += model_1(input1="input", input2="", output1=gelu5.input) + model += model_1( + input1=IOKey("input", differantiable=True), input2="", output1=gelu5.input + ) model += model_2( input2=gelu5.output, output2=model_1.input2, # type: ignore @@ -4244,7 +4246,10 @@ def test_cycle_handling_3(): model_2 += Sine()(input="input2", output=IOKey(name="output2")) model += gelu5(input="") model += model_1( - input1="input", slope=IOKey("slope"), input2="", output1=gelu5.input + input1=IOKey("input", differantiable=True), + slope=IOKey("slope"), + input2="", + output1=gelu5.input, ) model += model_2( input2=gelu5.output, @@ -5766,7 +5771,9 @@ def test_deepcopy_4(): model += deepcopy(_model) all_data = get_all_data(model) - compiled_model = mithril.compile(model=model, backend=NumpyBackend()) + compiled_model = mithril.compile( + model=model, backend=NumpyBackend(), safe_names=False + ) unused_data = { compiled_model.data.get(key) for key in compiled_model.flat_graph.unused_keys @@ -6122,6 +6129,7 @@ def test_leaky_relu_trainable_slope(): model = Model() model += LeakyRelu()(input="input", output="output", slope="slope") model.set_types(slope=Tensor) + model.set_differentiability(input=True, slope=True) pm = mithril.compile(model=model, backend=backend) params = {"input": backend.array([-2.0, 2.0]), "slope": backend.array(0.2)} @@ -6538,7 +6546,7 @@ def test_constant_6(): def test_iadd_1(): model = Model() - model += MatrixMultiply()(right="w1") + model += MatrixMultiply()(left="left", right="w1") model += MatrixMultiply()(right="w2") model += MatrixMultiply()(right="w3") model += MatrixMultiply()(right="w4") @@ -6568,7 +6576,7 @@ def test_iadd_2(): model += Sigmoid() model += MatrixMultiply()(left=model.cout, right="w4") - compiled_model = compile(model, JaxBackend()) + compiled_model = compile(model, JaxBackend(), safe_names=False) expected_connections: dict[str, list[str | set[str]]] = { "output_0": ["matrix_multiplication", {"left", "w1"}], @@ -6587,7 +6595,7 @@ def test_iadd_3(): model += (mult := MatrixMultiply())(left=sigmoid.output, right="w4") model.set_cout(mult.output) - compiled_model = compile(model, JaxBackend()) + compiled_model = compile(model, JaxBackend(), safe_names=False) expected_connections: dict[str, list[str | set[str]]] = { "output_2": ["sigmoid", {"input"}], @@ -6609,7 +6617,7 @@ def test_iadd_4(): model += model_sub() model += model_sub2() - compiled_model = compile(model, JaxBackend()) + compiled_model = compile(model, JaxBackend(), safe_names=False) expected_connections: dict[str, list[str | set[str]]] = { "out2_0": ["sigmoid", {"in2"}], @@ -6629,7 +6637,7 @@ def test_iadd_5(): model += model_sub model += model_sub2 - compiled_model = compile(model, JaxBackend()) + compiled_model = compile(model, JaxBackend(), safe_names=False) expected_connections: dict[str, list[str | set[str]]] = { "out1_0": ["sigmoid", {"in1"}], @@ -6675,7 +6683,7 @@ def test_iadd_7(): model += (mult := MatrixMultiply())(left=sigmoid.output, right="w4") model.set_cout(mult.output) - compiled_model = compile(model, JaxBackend()) + compiled_model = compile(model, JaxBackend(), safe_names=False) expected_connections: dict[str, list[str | set[str]]] = { "output_2": ["sigmoid", {"input"}], @@ -6838,8 +6846,8 @@ def __call__( # type: ignore[override] # Compile the model and assert the results pm = mithril.compile(model=model, backend=backend) input = backend.ones((7, 6)) - trainable_keys = {"input": input} - outputs = pm.evaluate(trainable_keys) + data = {"input": input} + outputs = pm.evaluate(data=data) ref_outputs = {"output": backend.ones(7) * 6} assert_results_equal(outputs, ref_outputs) @@ -6921,8 +6929,8 @@ def __call__( # type: ignore[override] # Compile the model and assert the results pm = mithril.compile(model=model, backend=backend, safe_names=False, jit=False) input = backend.ones((7, 6)) - trainable_keys = {"input": input} - outputs = pm.evaluate(trainable_keys) + data = {"input": input} + outputs = pm.evaluate(data=data) ref_outputs = {"output": backend.ones(7) * 6} assert_results_equal(outputs, ref_outputs) diff --git a/tests/scripts/test_set_values.py b/tests/scripts/test_set_values.py index ebe6fef5..c203c56e 100644 --- a/tests/scripts/test_set_values.py +++ b/tests/scripts/test_set_values.py @@ -88,7 +88,9 @@ def test_set_values_scalar_1(): backend = JaxBackend() model = Model() mean_model = Mean(axis=TBD) - model += mean_model(input="input", output=IOKey("output", shape=[2, 2])) + model += mean_model( + input=IOKey("input", differantiable=True), output=IOKey("output", shape=[2, 2]) + ) model.set_values({mean_model.axis: 1}) pm = mithril.compile(model=model, backend=JaxBackend()) @@ -110,7 +112,9 @@ def test_set_values_scalar_1_kwargs_arg(): backend = JaxBackend() model = Model() mean_model = Mean(axis=TBD) - model += mean_model(input="input", output=IOKey("output", shape=[2, 2])) + model += mean_model( + input=IOKey("input", differantiable=True), output=IOKey("output", shape=[2, 2]) + ) mean_model.set_values(axis=1) pm = mithril.compile(model=model, backend=JaxBackend()) @@ -133,7 +137,9 @@ def test_set_values_scalar_2(): model = Model() mean_model = Mean(axis=TBD) model += mean_model( - input="input", output=IOKey("output", shape=[2, 2]), axis="axis1" + input=IOKey("input", differantiable=True), + output=IOKey("output", shape=[2, 2]), + axis="axis1", ) model.set_values({model.axis1: 1}) # type: ignore @@ -157,7 +163,9 @@ def test_set_values_scalar_3(): model = Model() mean_model = Mean(axis=TBD) model += mean_model( - input="input", output=IOKey("output", shape=[2, 2]), axis="axis1" + input=IOKey("input", differantiable=True), + output=IOKey("output", shape=[2, 2]), + axis="axis1", ) model.set_values({"axis1": 1}) diff --git a/tests/scripts/test_summary.py b/tests/scripts/test_summary.py index 462d078f..7e566acf 100644 --- a/tests/scripts/test_summary.py +++ b/tests/scripts/test_summary.py @@ -773,7 +773,7 @@ def test_define_unique_names_1(): model |= KernelizedSVM_0(input1=model.cout) model |= KernelizedSVM_1(input1=model.cout) - lin_0.input.set_differentiable(True) + lin_0.set_differentiability(input=True) name_dict = define_unique_names(model.dag.keys()) assert name_dict == { lin_0: "Linear_0", @@ -919,7 +919,7 @@ def test_physical_summary_2(): model += Linear(dimension=3) model += model1 assert isinstance(model.cin, Connection) - model.cin.set_differentiable(True) + model.set_differentiability({model.cin: True}) comp_model = mithril.compile( model=model, backend=NumpyBackend(), shapes={"input": [5, 5]} @@ -938,15 +938,16 @@ def test_physical_summary_2(): def test_physical_summary_3(): model = Model() model_1 = KernelizedSVM(kernel=RBFKernel()) - model_1.input1.set_differentiable(True) - model_1.input2.set_differentiable(True) + model_1.set_differentiability(input1=True, input2=True) model_2 = MLP( activations=[Sigmoid(), Tanh(), Relu(), LeakyRelu()], dimensions=[3, 4, 5, 6] ) model += model_1 model += model_2 - comp_model = mithril.compile(model=model, backend=JaxBackend(), jit=False) + comp_model = mithril.compile( + model=model, backend=JaxBackend(), jit=False, safe_names=False + ) with redirect_stdout(StringIO()) as summary: comp_model.summary(verbose=True, shapes=True, symbolic=True, types=False) @@ -961,8 +962,8 @@ def test_physical_summary_3(): def test_physical_summary_3_logical_with_depth(): model = Model() model_1 = KernelizedSVM(kernel=RBFKernel()) - model_1.input1.set_differentiable(True) - model_1.input2.set_differentiable(True) + model_1.set_differentiability(input1=True, input2=True) + model_2 = MLP( activations=[Sigmoid(), Tanh(), Relu(), LeakyRelu()], dimensions=[3, 4, 5, 6] ) @@ -986,15 +987,14 @@ def test_physical_summary_3_logical_with_depth(): def test_physical_summary_4(): model = Model() model_1 = KernelizedSVM(kernel=RBFKernel()) - model_1.input1.set_differentiable(True) - model_1.input2.set_differentiable(True) + model_1.set_differentiability(input1=True, input2=True) model_1.set_cin("input1") model_2 = MLP( activations=[Sigmoid(), Tanh(), Relu(), LeakyRelu()], dimensions=[3, 4, 5, 6] ) model += model_1 model += model_2 - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(model=model_2, shapes=True, verbose=True, depth=1) @@ -1035,14 +1035,14 @@ def test_physical_model_summary_5(): model += add model += divide model += exp - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(verbose=True, shapes=True, symbolic=True) ref_table = "" with open("tests/scripts/summary_txts/test_physical_model_summary_5") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_physical_model_summary_6(): @@ -1060,7 +1060,7 @@ def test_physical_model_summary_6(): ) random_kernel_model.set_shapes({"input1": ["N", "M"], "input2": ["N", "M"]}) - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(verbose=True, shapes=True, symbolic=True) @@ -1080,7 +1080,9 @@ def test_physical_model_summary_7(): ) random_kernel_model.set_shapes({"input1": ["N", "M"], "input2": ["N", "M"]}) - comp_model = mithril.compile(model=random_kernel_model, backend=JaxBackend()) + comp_model = mithril.compile( + model=random_kernel_model, backend=JaxBackend(), safe_names=False + ) with redirect_stdout(StringIO()) as summary: comp_model.summary(verbose=True, shapes=True, symbolic=True) @@ -1108,13 +1110,13 @@ def test_physical_model_summary_8(): model += random_kernel_model model += another_random_model - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(verbose=True, shapes=False) ref_table = "" with open("tests/scripts/summary_txts/test_physical_model_summary_8") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_physical_model_summary_9(): @@ -1127,7 +1129,7 @@ def test_physical_model_summary_9(): model += random_kernel_model model += Relu() - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(verbose=True, shapes=True, symbolic=True) @@ -1135,7 +1137,7 @@ def test_physical_model_summary_9(): with open("tests/scripts/summary_txts/test_physical_model_summary_9") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_physical_summary_10(): @@ -1144,7 +1146,9 @@ def test_physical_summary_10(): sig_model2 = Sigmoid() model += sig_model1(input="input", output=IOKey("output1")) model += sig_model2(input="input", output=IOKey("output2")) - comp_model = mithril.compile(model=model, backend=JaxBackend(), jit=False) + comp_model = mithril.compile( + model=model, backend=JaxBackend(), jit=False, safe_names=False + ) with redirect_stdout(StringIO()) as summary: comp_model.summary( verbose=True, shapes=True, symbolic=True, model=sig_model1, types=True @@ -1163,7 +1167,7 @@ def test_physical_summary_11(): sig_model2 = Sigmoid() model += sig_model1(input="input", output=IOKey(name="output1")) model += sig_model2(input="input", output=IOKey(name="output2")) - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(verbose=True, shapes=True, symbolic=True, model=sig_model2) ref_table = "" @@ -1178,13 +1182,13 @@ def test_physical_summary_12(): sig_model2 = Sigmoid() model += sig_model1(input="input", output=IOKey(name="output1")) model += sig_model2(input="input", output=IOKey(name="output2")) - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(verbose=True, shapes=True, symbolic=True) ref_table = "" with open("tests/scripts/summary_txts/test_physical_summary_12") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_physical_summary_13(): @@ -1209,7 +1213,7 @@ def test_physical_summary_14(): model += sig_model1(left="left", right="right", output=IOKey("output1")) model += sig_model2(left="left", right="right", output=IOKey("output2")) comp_model = mithril.compile( - model=model, backend=JaxBackend(), shapes={"left": [3, 4, 5]} + model=model, backend=JaxBackend(), shapes={"left": [3, 4, 5]}, safe_names=False ) with redirect_stdout(StringIO()) as summary: comp_model.summary(model=sig_model2, verbose=True) @@ -1239,7 +1243,7 @@ def test_physical_summary_15(): model += lin_model_4( input="input", weight="weight", bias="b", output=IOKey(name="output4") ) - lin_model_1.input.set_differentiable(True) + lin_model_1.set_differentiability(input=True) comp_model = mithril.compile(model=model, backend=JaxBackend(), jit=False) @@ -1270,7 +1274,7 @@ def test_physical_summary_16(): input="input", weight="weight", bias="b", output=IOKey(name="output3") ) - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(model=add_model_1, verbose=True, types=True) @@ -1292,9 +1296,9 @@ def test_physical_summary_17(): model += lin_model_2(input="input", weight="weight", bias="b", output="output2") model += lin_model_3(input="input", weight="weight", bias="b", output="output3") model.set_cout("output3") - lin_model_1.input.set_differentiable(True) + lin_model_1.set_differentiability(input=True) - comp_model = mithril.compile(model=model, backend=JaxBackend()) + comp_model = mithril.compile(model=model, backend=JaxBackend(), safe_names=False) with redirect_stdout(StringIO()) as summary: comp_model.summary(model=matmul_model_1, verbose=True, types=True) @@ -1309,7 +1313,8 @@ def test_physical_summary_17(): def test_resnet_18_physical_summary(): model = resnet18(1) assert isinstance(model.cin, Connection) - model.cin.set_differentiable(True) + model.set_differentiability({model.cin: True}) + comp_model = mithril.compile(model=model, backend=TorchBackend(), jit=False) with redirect_stdout(StringIO()) as summary: @@ -1717,7 +1722,7 @@ def test_primitive_model_summary_4(): with open("tests/scripts/summary_txts/test_primitive_model_summary_4") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_primitive_model_summary_5(): @@ -1729,7 +1734,7 @@ def test_primitive_model_summary_5(): with open("tests/scripts/summary_txts/test_primitive_model_summary_5") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_primitive_model_summary_6(): @@ -1765,7 +1770,7 @@ def test_primitive_model_summary_8(): with open("tests/scripts/summary_txts/test_primitive_model_summary_8") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_primitive_model_summary_9(): @@ -1868,7 +1873,11 @@ def test_traincontext_summary_3(): add_1.set_types(left=Tensor, right=Tensor) add_2.set_types(left=Tensor, right=Tensor) matmul_1 = MatrixMultiply() - model += add_1(left="in1", right="in2", output=IOKey(name="output1")) + model += add_1( + left=IOKey("in1", differantiable=True), + right="in2", + output=IOKey(name="output1"), + ) model += add_2(left="", output=IOKey(name="output2")) model += matmul_1(left="", output=IOKey(name="output3")) model.set_cin(matmul_1.left) @@ -1897,7 +1906,7 @@ def test_traincontext_summary_3(): with open("tests/scripts/summary_txts/test_traincontext_summary_3") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_traincontext_summary_4(): @@ -1907,7 +1916,11 @@ def test_traincontext_summary_4(): add_1.set_types(left=Tensor, right=Tensor) add_2.set_types(left=Tensor, right=Tensor) matmul_1 = MatrixMultiply() - model += add_1(left="in1", right="in2", output=IOKey(name="output1")) + model += add_1( + left=IOKey("in1", differantiable=True), + right=IOKey("in2", differantiable=True), + output=IOKey(name="output1"), + ) model += add_2(left="", output=IOKey(name="output2")) model += matmul_1(left="", output=IOKey(name="output3")) model.set_cin(matmul_1.left) @@ -1938,7 +1951,7 @@ def test_traincontext_summary_4(): with open("tests/scripts/summary_txts/test_traincontext_summary_4") as f: ref_table = f.read() - assert "\n" + summary.getvalue() == ref_table + assert summary.getvalue() == ref_table def test_traincontext_summary_5(): @@ -1948,7 +1961,11 @@ def test_traincontext_summary_5(): add_1.set_types(left=Tensor, right=Tensor) add_2.set_types(left=Tensor, right=Tensor) matmul_1 = MatrixMultiply() - model += add_1(left="in1", right="in2", output=IOKey(name="output1")) + model += add_1( + left=IOKey("in1", differantiable=True), + right=IOKey("in2", differantiable=True), + output=IOKey(name="output1"), + ) model += add_2(output=IOKey(name="output2")) model += matmul_1(output=IOKey(name="output3")) ctx = TrainModel(model) diff --git a/tests/scripts/test_train_context.py b/tests/scripts/test_train_context.py index f00fe3db..4e0ba410 100644 --- a/tests/scripts/test_train_context.py +++ b/tests/scripts/test_train_context.py @@ -70,7 +70,7 @@ def test_add_loss_case_2(): inputs = {"input": np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])} - model += relu1(input="input") + model += relu1(input=IOKey("input", differantiable=True)) model += relu2(input=relu1.output) model += relu3(input=relu2.output, output=IOKey(name="output")) @@ -147,7 +147,7 @@ def test_add_loss_case_3(): "input": backend.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) } - model += relu1(input="input") + model += relu1(input=IOKey("input", differantiable=True)) model += relu2(input=relu1.output) model += relu3(input=relu2.output, output=IOKey(name="output")) @@ -269,7 +269,7 @@ def test_add_loss_case_8(): relu2 = Relu() relu3 = Relu() - model += relu1(input="input") + model += relu1(input=IOKey("input", differantiable=True)) model += relu2(input=relu1.output, output=IOKey(name="output1")) model += relu3(input=relu1.output, output=IOKey(name="output2")) @@ -302,7 +302,7 @@ def test_add_loss_case_9(): sigmoid2 = Relu() sigmoid3 = Relu() - model += sigmoid1(input="input") + model += sigmoid1(input=IOKey("input", differantiable=True)) model += sigmoid2(input=sigmoid1.output, output=IOKey(name="output1")) model += sigmoid3(input=sigmoid1.output, output=IOKey(name="output2")) diff --git a/tests/scripts/test_tuple_list_args_in_extend.py b/tests/scripts/test_tuple_list_args_in_extend.py index 9940efe6..5346f776 100644 --- a/tests/scripts/test_tuple_list_args_in_extend.py +++ b/tests/scripts/test_tuple_list_args_in_extend.py @@ -27,7 +27,9 @@ def test_tuple_argument_1(): model = Model() add = Add() model += add( - left=IOKey("left", type=Tensor), right=Tensor([3.0, 4, 5]), output="output" + left=IOKey("left", type=Tensor, differantiable=True), + right=Tensor([3.0, 4, 5]), + output="output", ) pm = compile(model=model, backend=backend) @@ -240,7 +242,9 @@ def test_list_argument_1(): model = Model() add = Add() model += add( - left=IOKey("left", type=Tensor), right=Tensor([3.0, 4, 5]), output="output" + left=IOKey("left", differantiable=True), + right=Tensor([3.0, 4, 5]), + output="output", ) pm = compile(model=model, backend=backend) diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index 78e974cd..e3475f7d 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -159,11 +159,11 @@ def test_scalar_to_tensor_3(): tensor_1 = ToTensor() tensor_2 = ToTensor() model += tensor_1(input=[[[1]]]) - model += add_1(left=tensor_1.output, right=IOKey("right", type=Tensor)) + model += add_1(left=tensor_1.output, right=IOKey("right", differantiable=True)) model += shp_1(input=add_1.output) model += tensor_2(input=shp_1.output) model += Add()( - left=IOKey("left", type=Tensor), + left=IOKey("left", differantiable=True), right=tensor_2.output, output="output", ) @@ -176,11 +176,11 @@ def test_scalar_to_tensor_3(): shp_2 = Shape() model += add_2( left=IOKey(value=[[[1]]]).tensor(), - right=IOKey("right", type=Tensor), + right=IOKey("right", differantiable=True), ) model += shp_2(input=add_2.output) model += Add()( - left=IOKey("left", type=Tensor), + left=IOKey("left", differantiable=True), right=shp_2.output.tensor(), output="output", ) @@ -889,7 +889,11 @@ def test_connect_type_conv_handling_1(): ) model._extend( mat_mul := MatrixMultiply(), - {"left": con_object, "output": IOKey(name="output")}, + { + "left": con_object, + "right": IOKey(differantiable=True), + "output": IOKey(name="output"), + }, ) mat_mul.set_shapes({"output": output_shape}) model_1 = model @@ -906,7 +910,11 @@ def test_connect_type_conv_handling_1(): ) model._extend( (mat_mul := MatrixMultiply()), - {"left": con_object, "output": IOKey(name="output")}, + { + "left": con_object, + "right": IOKey(differantiable=True), + "output": IOKey(name="output"), + }, ) mat_mul.set_shapes({"output": output_shape}) model_2 = model @@ -923,7 +931,11 @@ def test_connect_type_conv_handling_1(): ) model._extend( (mat_mul := MatrixMultiply()), - {"left": con_object, "output": IOKey(name="output")}, + { + "left": con_object, + "right": IOKey(differantiable=True), + "output": IOKey(name="output"), + }, ) mat_mul.set_shapes({"output": output_shape}) model_3 = model @@ -1085,10 +1097,15 @@ def test_connect_7(): model = Model() add_model_1 = Add() add_model_2 = Add() - add_model_1.set_types(left=Tensor, right=Tensor) - add_model_2.set_types(left=Tensor, right=Tensor) - model += add_model_1(left="left", right="right", output=IOKey(name="output2")) - model += add_model_2(left="left1", right="right1") + model += add_model_1( + left=IOKey("left", differantiable=True), + right=IOKey("right", differantiable=True), + output=IOKey(name="output2"), + ) + model += add_model_2( + left=IOKey("left1", differantiable=True), + right=IOKey("right1", differantiable=True), + ) conn = IOKey( connections={add_model_2.output, model.right}, # type: ignore @@ -1139,10 +1156,15 @@ def test_connect_7_expose_output(): model = Model() add_model_1 = Add() add_model_2 = Add() - add_model_1.set_types(left=Tensor, right=Tensor) - add_model_2.set_types(left=Tensor, right=Tensor) - model += add_model_1(left="left", right="right", output=IOKey(name="output2")) - model += add_model_2(left="left1", right="right1") + model += add_model_1( + left=IOKey("left", differantiable=True), + right=IOKey("right", differantiable=True), + output=IOKey(name="output2"), + ) + model += add_model_2( + left=IOKey("left1", differantiable=True), + right=IOKey("right1", differantiable=True), + ) conns = {add_model_2.output, model.right} # type: ignore conn = IOKey(name="abcd", expose=True, connections=conns) # type: ignore model += (buf := Buffer())(input=conn, output=IOKey(name="output")) @@ -1195,11 +1217,14 @@ def test_connect_8(): model = Model() add_model_1 = Add() add_model_2 = Add() - add_model_1.set_types(left=Tensor, right=Tensor) - add_model_2.set_types(left=Tensor, right=Tensor) - model += add_model_1(left="left", right="right") + model += add_model_1( + left=IOKey("left", differantiable=True), + right=IOKey("right", differantiable=True), + ) model += add_model_2( - left=add_model_1.output, right="right1", output=IOKey(name="output1") + left=add_model_1.output, + right=IOKey("right1", differantiable=True), + output=IOKey(name="output1"), ) conn = IOKey( connections={add_model_1.output, model.right1}, # type: ignore @@ -1349,7 +1374,7 @@ def test_tensor_to_scalar_4(): # Auto conversion auto_model += Relu()(input="input") auto_model += (shp := Shape()) - auto_model += Add()(left=shp.output.tensor(), right=IOKey(type=Tensor)) + auto_model += Add()(left=shp.output.tensor(), right=IOKey(differantiable=True)) # Manuel conversion manual_model = Model() @@ -1357,7 +1382,7 @@ def test_tensor_to_scalar_4(): manual_model += Relu()(input="input") manual_model += Shape() manual_model += ToTensor() - manual_model += Add()(right=IOKey(type=Tensor)) + manual_model += Add()(right=IOKey(differantiable=True)) backend = TorchBackend() @@ -1418,8 +1443,8 @@ def test_coercion_1(): add_model_1 = Add() add_model_2 = Add() - model += reduce_model_1(input="input1", axis="axis1") - model += reduce_model_2(input="input2", axis="axis2") + model += reduce_model_1(input=IOKey("input1", differantiable=True), axis="axis1") + model += reduce_model_2(input=IOKey("input2", differantiable=True), axis="axis2") model += add_model_1( left=reduce_model_1.axis.tensor(), right=reduce_model_2.axis.tensor() ) @@ -1459,8 +1484,8 @@ def test_coercion_2(): reduce_model_1 = Sum(axis=TBD) reduce_model_2 = Sum(axis=TBD) l_relu = LeakyRelu() - model += reduce_model_1(input="input1", axis="axis1") - model += reduce_model_2(input="input2", axis="axis2") + model += reduce_model_1(input=IOKey("input1", differantiable=True), axis="axis1") + model += reduce_model_2(input=IOKey("input2", differantiable=True), axis="axis2") axis1 = reduce_model_1.axis.tensor().sum() axis2 = reduce_model_2.axis.tensor().sum() @@ -1511,7 +1536,9 @@ def test_coercion_3(): right=IOKey(value=[0, 1]).tensor(), ) model += (to_list := TensorToList())(input=add_model.output) - model += reduce_model(input="input", axis=to_list.output, output="output") + model += reduce_model( + input=IOKey("input", differantiable=True), axis=to_list.output, output="output" + ) pm = compile(model=model, backend=backend, jit=False) params = {"input": backend.ones(1, 2, 3, 4, 5)} @@ -1537,7 +1564,9 @@ def test_coercion_4(): right=IOKey(value=[0, 1]).tensor(), ) model += (to_list := TensorToList())(input=add_model.output) - model += reduce_model(input="input", axis=to_list.output, output="output") + model += reduce_model( + input=IOKey("input", differantiable=True), axis=to_list.output, output="output" + ) pm = compile(model=model, backend=backend, jit=False) @@ -1556,7 +1585,7 @@ def test_coercion_5(): model = Model(enforce_jit=False) add = Add() to_list = TensorToList() - model += add(left=IOKey("left", type=Tensor), right=Tensor([2.0])) + model += add(left=IOKey("left", differantiable=True), right=Tensor([2.0])) model += to_list(input=add.output) model += Buffer()(input=to_list.output.tensor(), output="output") @@ -1605,12 +1634,9 @@ def test_tensor_to_scalar_template_2(): buff_model_1 = Buffer() buff_model_2 = Buffer() buff_model_3 = Buffer() - buff_model_1.set_types(input=Tensor) - buff_model_2.set_types(input=Tensor) - buff_model_3.set_types(input=Tensor) - model += buff_model_1(input="input1") - model += buff_model_2(input="input2") - model += buff_model_3(input="input3") + model += buff_model_1(input=IOKey("input1", differantiable=True)) + model += buff_model_2(input=IOKey("input2", differantiable=True)) + model += buff_model_3(input=IOKey("input3", differantiable=True)) in1 = buff_model_1.output in2 = buff_model_2.output