diff --git a/mithril/backends/with_autograd/common_primitives.py b/mithril/backends/with_autograd/common_primitives.py index fc441777..50e3a064 100644 --- a/mithril/backends/with_autograd/common_primitives.py +++ b/mithril/backends/with_autograd/common_primitives.py @@ -41,7 +41,6 @@ "transpose", "swapaxes", "square", - "tensor_slice", "primitive_slice", "buffer", "permute_tensor", @@ -160,12 +159,6 @@ def square(input: DataType): return input * input -def tensor_slice( - input: DataType, start: int | None, stop: int | None, step: int | None -): - return input[start:stop:step] - - def buffer(input: DataType): return input diff --git a/mithril/backends/with_autograd/jax_backend/ops.py b/mithril/backends/with_autograd/jax_backend/ops.py index 56488fb1..8c338dca 100644 --- a/mithril/backends/with_autograd/jax_backend/ops.py +++ b/mithril/backends/with_autograd/jax_backend/ops.py @@ -64,7 +64,6 @@ subtract, swapaxes, tensor_item, - tensor_slice, to_list, to_tuple, transpose, @@ -192,7 +191,6 @@ "transpose", "swapaxes", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", diff --git a/mithril/backends/with_autograd/mlx_backend/ops.py b/mithril/backends/with_autograd/mlx_backend/ops.py index 542e1953..ecb14fe3 100644 --- a/mithril/backends/with_autograd/mlx_backend/ops.py +++ b/mithril/backends/with_autograd/mlx_backend/ops.py @@ -62,7 +62,6 @@ subtract, swapaxes, tensor_item, - tensor_slice, to_list, to_tuple, transpose, @@ -164,7 +163,6 @@ "transpose", "swapaxes", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", diff --git a/mithril/backends/with_autograd/torch_backend/ops.py b/mithril/backends/with_autograd/torch_backend/ops.py index 71cbffa3..9b4261b0 100644 --- a/mithril/backends/with_autograd/torch_backend/ops.py +++ b/mithril/backends/with_autograd/torch_backend/ops.py @@ -63,7 +63,6 @@ subtract, swapaxes, tensor_item, - tensor_slice, to_list, to_tuple, tuple_converter, @@ -181,7 +180,6 @@ "transpose", "swapaxes", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", diff --git a/mithril/backends/with_manualgrad/common_primitives.py b/mithril/backends/with_manualgrad/common_primitives.py index 46afc63e..a67fa9c3 100644 --- a/mithril/backends/with_manualgrad/common_primitives.py +++ b/mithril/backends/with_manualgrad/common_primitives.py @@ -43,7 +43,6 @@ "squared_error", "transpose", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", @@ -165,16 +164,6 @@ def square(input: DataType, cache: CacheType = None): return input * input -def tensor_slice( - input: DataType, - start: int | None, - stop: int | None, - step: int | None, - cache: CacheType = None, -): - return input[start:stop:step] - - def buffer(input: DataType, cache: CacheType = None): return input diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops.py b/mithril/backends/with_manualgrad/numpy_backend/ops.py index 563dec24..cd29c29f 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops.py @@ -63,7 +63,6 @@ subtract, swapaxes, tensor_item, - tensor_slice, to_list, to_tuple, transpose, @@ -182,7 +181,6 @@ "squared_error", "transpose", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py index a5232ebb..d4584aa5 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py @@ -76,7 +76,6 @@ "softplus_grad", "gelu_grad", "stop_gradient_grad", - "tensor_slice_grad", "tensor_item_grad", "permute_tensor_grad", "transpose_grad", @@ -810,19 +809,6 @@ def stop_gradient_grad( return np.zeros_like(output_gradient) -def tensor_slice_grad( - output_gradient: np.ndarray[Any, Any], - cache: CacheType, - idx: int, - *inputs: np.ndarray[Any, Any], -) -> np.ndarray[Any, Any]: - verify_shapes(inputs, idx, non_differentiables=[1, 2, 3, 4]) - input, start, stop, step = inputs - grad = np.zeros_like(input) - grad[start:stop:step] = output_gradient - return grad - - def tensor_item_grad( output_gradient: np.ndarray[Any, Any], cache: CacheType, diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index dceff75e..7ef655ba 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -87,7 +87,6 @@ "scalar_item_constraints", "to_tuple_constraints", "tensor_item_constraints", - "tensor_slice_constraints", "tensor_to_list_type_constraint", "reduce_type_constraint", "type_constraints", @@ -3466,36 +3465,6 @@ def tensor_item_constraint_helper( return input_unis, output_unis, status, current_index -def tensor_slice_constraints( - output: Tensor, input: Tensor, start: Scalar, stop: Scalar, step: Scalar -) -> ConstrainResultType: - assert output._temp_shape is not None, "Output shape of TensorSlice is not set!" - assert input._temp_shape is not None, "Input shape of TensorSlice is not set!" - output_shape: ShapeRepr = output._temp_shape - input_shape: ShapeRepr = input._temp_shape - updated_symbols = Updates() - status = False - - if input_shape.prefix and output_shape.prefix: - in_uni, out_uni = input_shape[0], output_shape[0] - if in_uni.value is not None and out_uni.value is not None: - status = True - else: - if ( - start.value is not TBD - and stop.value is not TBD - and step.value is not TBD - and in_uni.value is not None - ): - slc = slice(start.value, stop.value, step.value) - out_val = len(list(range(in_uni.value))[slc]) - out_uni.set_value(out_val) - updated_symbols.add(out_uni) - status = True - - return status, updated_symbols - - def split_constraints(output: Tensor, input: Tensor, split_size: Scalar, axis: Scalar): status = False split_size_val = split_size.value diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index 051d9c94..c941aee6 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -47,7 +47,6 @@ slice_constraints, split_constraints, tensor_item_constraints, - tensor_slice_constraints, tensor_to_list_constraints, tensor_to_list_type_constraint, to_list_constraints, @@ -101,7 +100,6 @@ "ShiftLeft", "ShiftRight", "TensorItem", - "TensorSlice", "ArgMax", "ArgMin", "Cast", @@ -606,54 +604,6 @@ def __call__( # type: ignore[override] ) -class TensorSlice(PrimitiveModel): - input: Connection - start: Connection - stop: Connection - step: Connection - output: Connection - - def __init__( - self, - name: str | None = None, - start: int | None | ToBeDetermined = None, - stop: int | None | ToBeDetermined = None, - step: int | None | ToBeDetermined = None, - input: TensorValueType | ToBeDetermined = TBD, - ) -> None: - self.factory_args = {"start": start, "stop": stop, "step": step} - super().__init__( - formula_key="tensor_slice", - name=name, - output=BaseKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - input=BaseKey(shape=["b", ("Var1", ...)], type=GenericTensorType), - start=BaseKey(type=int | None, value=start), - stop=BaseKey(type=int | None, value=stop), - step=BaseKey(type=int | None, value=step), - ) - self.factory_inputs = {"input": input} - - self._set_constraint( - fn=tensor_slice_constraints, - keys=[PrimitiveModel.output_key, "input", "start", "stop", "step"], - ) - self._set_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] - ) - - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - start: ConnectionType = NOT_GIVEN, - stop: ConnectionType = NOT_GIVEN, - step: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__( - input=input, start=start, stop=stop, step=step, output=output - ) - - class Item(PrimitiveModel): input: Connection output: Connection @@ -718,50 +668,6 @@ def __call__( # type: ignore[override] return super().__call__(input=input, index=index, output=output) -class TensorItem(PrimitiveModel): - input: Connection - index: Connection - output: Connection - - def __init__( - self, - name: str | None = None, - index: int | ToBeDetermined = TBD, - input: TensorValueType | ToBeDetermined = TBD, - ) -> None: - super().__init__( - formula_key="tensor_item", - name=name, - output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), - input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), - index=BaseKey( - type=int - | slice - | EllipsisType - | None - | tuple[int | slice | EllipsisType | None, ...], - value=index, - ), - ) - self.factory_inputs = {"input": input, "index": index} - - self._set_constraint( - fn=tensor_item_constraints, - keys=[PrimitiveModel.output_key, "input", "index"], - ) - self._set_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] - ) - - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - index: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, index=index, output=output) - - class ToTensor(PrimitiveModel): input: Connection output: Connection @@ -1523,9 +1429,9 @@ class Slice(PrimitiveModel): def __init__( self, - start: int | None | ToBeDetermined = 0, - stop: int | None | ToBeDetermined = None, - step: int | None | ToBeDetermined = None, + start: int | None | ToBeDetermined = TBD, + stop: int | None | ToBeDetermined = TBD, + step: int | None | ToBeDetermined = TBD, name: str | None = None, ): super().__init__( @@ -1550,3 +1456,47 @@ def __call__( # type: ignore[override] output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: return super().__call__(start=start, stop=stop, step=step, output=output) + + +class TensorItem(PrimitiveModel): + input: Connection + index: Connection + output: Connection + + def __init__( + self, + name: str | None = None, + index: int | ToBeDetermined = TBD, + input: TensorValueType | ToBeDetermined = TBD, + ) -> None: + super().__init__( + formula_key="tensor_item", + name=name, + output=BaseKey(shape=[("Var2", ...)], type=GenericTensorType), + input=BaseKey(shape=[("Var1", ...)], type=GenericTensorType), + index=BaseKey( + type=int + | slice + | EllipsisType + | None + | tuple[int | slice | EllipsisType | None, ...], + value=index, + ), + ) + self.factory_inputs = {"input": input, "index": index} + + self._set_constraint( + fn=tensor_item_constraints, + keys=[PrimitiveModel.output_key, "input", "index"], + ) + self._set_constraint( + fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + ) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + index: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, index=index, output=output) diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index e9d976ea..83b8cc4a 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -86,7 +86,6 @@ Subtract, Sum, TensorItem, - TensorSlice, TensorToList, ToList, ToTensor, @@ -143,7 +142,6 @@ coercion_table: dict[tuple[str, type[Tensor] | type[Scalar]], type[PrimitiveModel]] = { ("get_item", Tensor): TensorItem, ("get_item", Scalar): ScalarItem, - ("slice", Tensor): TensorSlice, ("slice", Scalar): PrimitiveSlice, } diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index d7e8340e..ed7e8443 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -512,6 +512,9 @@ def _pre_compile( self.jacobian_keys = jacobian_keys self.ignore_grad_keys: set[str] = set() + # Set given shapes. + self.data_store.set_shapes(shapes) + for node in self._flat_graph.nodes.values(): conn_data = node.model.conns.get_connection("output") assert conn_data is not None @@ -552,9 +555,6 @@ def _pre_compile( self.data_store.constraint_solver(updates) - # Set given shapes. - self.data_store.set_shapes(shapes) - # Set given static keys self.data_store.set_static_keys(constant_keys) diff --git a/mithril/models/models.py b/mithril/models/models.py index 47ad5b2a..158f64b9 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -51,10 +51,11 @@ ScalarItem, Shape, Size, + Slice, Sqrt, Subtract, Sum, - TensorSlice, + TensorItem, Transpose, Variance, ) @@ -1375,7 +1376,10 @@ def __init__( shape = Shape() scalar_item = ScalarItem() - slice_model = TensorSlice(stop=TBD) + slice_1 = Slice(stop=None, step=None) + slice_2 = Slice(start=None, step=None) + tensor_item_1 = TensorItem() + tensor_item_2 = TensorItem() mult_model_1 = Linear(use_bias=False) mult_model_2 = Linear(use_bias=False) mult_model_3 = Linear(use_bias=False) @@ -1384,13 +1388,15 @@ def __init__( self += shape(input="input") self += scalar_item(input=shape.output, index=0) - self += TensorSlice(start=TBD)( + self += slice_1(start=scalar_item.output) + self += tensor_item_1( input="prev_hidden", - start=scalar_item.output, + index=slice_1.output, output=IOKey(name="hidden_compl"), ) - self += slice_model(input="prev_hidden", stop=scalar_item.output) - self += mult_model_1(input=slice_model.output, weight="w_hh") + self += slice_2(stop=scalar_item.output) + self += tensor_item_2(input="prev_hidden", index=slice_2.output) + self += mult_model_1(input=tensor_item_2.output, weight="w_hh") self += mult_model_2(input="input", weight="w_ih") self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2(left=sum_model_1.output, right="bias_h") @@ -1507,40 +1513,52 @@ def __init__( cell_body = LSTMCellBody() shape_model = Shape() scalar_item = ScalarItem() - slice_model_1 = TensorSlice(stop=TBD) - slice_model_2 = TensorSlice(stop=TBD) - slice_model_3 = TensorSlice(start=TBD) - slice_model_4 = TensorSlice(stop=TBD) + + slice_1 = Slice(start=None, step=None) + slice_2 = Slice(start=None, step=None) + slice_3 = Slice(stop=None, step=None) + slice_4 = Slice(start=None, step=None) + slice_5 = Slice(stop=None, step=None) + + tensor_item_1 = TensorItem() + tensor_item_2 = TensorItem() + tensor_item_3 = TensorItem() + tensor_item_4 = TensorItem() + tensor_item_5 = TensorItem() self += shape_model(input="input") self += scalar_item(input=shape_model.output, index=0) # Forget gate processes. - self += slice_model_1(input="prev_cell", stop=scalar_item.output) - self += slice_model_2(input="prev_hidden", stop=scalar_item.output) + self += slice_1(stop=scalar_item.output) + self += tensor_item_1(input="prev_cell", index=slice_1.output) + + self += slice_2(stop=scalar_item.output) + self += tensor_item_2(input="prev_hidden", index=slice_2.output) body_kwargs: dict[str, ConnectionType] = { key: key for key in cell_body._input_keys if key[0] != "$" } - body_kwargs["prev_cell"] = slice_model_1.output - body_kwargs["prev_hidden"] = slice_model_2.output + body_kwargs["prev_cell"] = tensor_item_1.output + body_kwargs["prev_hidden"] = tensor_item_2.output self += cell_body(**body_kwargs) - self += slice_model_3( - input=cell_body.output, - start=scalar_item.output, - output=IOKey(name="hidden"), + self += slice_3(start=scalar_item.output) + self += tensor_item_3( + input=cell_body.output, index=slice_3.output, output=IOKey(name="hidden") ) - self += slice_model_4( - input=cell_body.output, stop=scalar_item.output, output=IOKey(name="cell") + self += slice_4(stop=scalar_item.output) + self += tensor_item_4( + input=cell_body.output, index=slice_4.output, output=IOKey(name="cell") ) # Slice complement process. - self += TensorSlice(start=TBD)( + self += slice_5(start=scalar_item.output) + self += tensor_item_5( input="prev_hidden", - start=scalar_item.output, + index=slice_5.output, output=IOKey(name="hidden_compl"), ) # Final output. @@ -1801,7 +1819,8 @@ def __init__( # current time step. shape_model = Shape() item_model = ScalarItem() - slice_model = TensorSlice(stop=TBD) + slice_model = Slice(start=None, step=None) + tensor_item = TensorItem() self += shape_model(input=f"target{idx}") self += item_model(input=shape_model.output, index=0) @@ -1817,9 +1836,10 @@ def __init__( # of previous time step as inputs to the current time step. slice_input_1 = getattr(prev_cell, prev_cell.out_key) - self += slice_model(input=slice_input_1, stop=item_model.output) + self += slice_model(stop=item_model.output) + self += tensor_item(input=slice_input_1, index=slice_model.output) - input_kwargs = {"input": slice_model.output} + input_kwargs = {"input": tensor_item.output} output_kwargs = {cell_type.out_key: IOKey(name=f"output{idx}")} self += current_cell( diff --git a/tests/json_files/models_directed_test.json b/tests/json_files/models_directed_test.json index 30f062bc..8aab8739 100644 --- a/tests/json_files/models_directed_test.json +++ b/tests/json_files/models_directed_test.json @@ -4120,84 +4120,6 @@ } }, - "test_tensor_slice_1": { - - "model": { - "name": "TensorSlice", - "args": { - "start": 0, - "stop": 1, - "step": null - } - }, - "inputs": { - "input": [ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0] - ] - }, - - "output_grads": { - "output": [ - [5.0, 6.0] - ] - }, - "results": { - "eval": { - "output": [ - [1.0, 2.0] - ] - }, - "grad": { - "input": [[5.0, 6.0], - [0.0, 0.0], - [0.0, 0.0]] - } - } - }, - - "test_tensor_slice_2": { - - "model": { - "name": "TensorSlice", - "args": { - "start": 0, - "stop": 2, - "step": null - } - }, - "inputs": { - "input": [ - [[1.0, 2.0]], - [[3.0, 4.0]], - [[5.0, 6.0]] - ] - }, - - "output_grads": { - "output": [ - [[3.0, 0.0]], - [[2.0, 1.0]] - ] - }, - "results": { - "eval": { - "output": [ - [[1.0, 2.0]], - [[3.0, 4.0]] - ] - }, - "grad": { - "input": [ - [[3.0, 0.0]], - [[2.0, 1.0]], - [[0.0, 0.0]] - ] - } - } - }, - "test_tanh_1": { "model": { diff --git a/tests/json_files/primitives_directed_test.json b/tests/json_files/primitives_directed_test.json index 43ca76d8..b26fb856 100644 --- a/tests/json_files/primitives_directed_test.json +++ b/tests/json_files/primitives_directed_test.json @@ -1023,67 +1023,6 @@ } } }, - "test_tensor_slice_1": { - - "model": { - "name": "TensorSlice", - "args": { - "start": 0, - "stop": 1, - "step": null - } - }, - "inputs": { - "input": [ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0] - ] - }, - - "output_grad": [[5.0, 6.0]], - - "results": { - "eval": [[1.0, 2.0]], - "grad": { - "input": [[5.0, 6.0], - [0.0, 0.0], - [0.0, 0.0]] - } - } - }, - - "test_tensor_slice_2": { - - "model": { - "name": "TensorSlice", - "args": { - "start": 0, - "stop": 2, - "step": null - } - }, - "inputs": { - "input": [ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0] - ] - }, - - "output_grad": [[3.0, 0.0], [2.0, 1.0]], - - "results": { - "eval": [[1.0, 2.0], [3.0, 4.0]], - "grad": { - "input": [ - [3.0, 0.0], - [2.0, 1.0], - [0.0, 0.0] - ] - } - } - }, "test_tanh_1": { "model": "Tanh", "inputs": { diff --git a/tests/json_files/randomized_model_tests_all_backends.json b/tests/json_files/randomized_model_tests_all_backends.json index 38b1595e..82c0609b 100644 --- a/tests/json_files/randomized_model_tests_all_backends.json +++ b/tests/json_files/randomized_model_tests_all_backends.json @@ -2055,22 +2055,6 @@ }, "iterations": 5 }, - "test_tensor_slice": { - "model": { - "name": "TensorSlice", - "randomized_args": { - "start": [1,3], - "stop": [20, 30], - "step": [1, 4] - } - }, - "input_info": { - "input": { - "shapes": [[35,40], [3,3], [3,3], [3,3], [3,3]] - } - }, - "iterations": 20 - }, "test_where": { "model": { "name": "Where" diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index f6e03a14..066b7fe9 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -40,6 +40,7 @@ Greater, GreaterEqual, GroupNorm, + IOKey, IsNan, Less, LessEqual, @@ -50,6 +51,7 @@ LogicalOr, LogicalXOr, Minus, + Model, NanToNum, NormModifier, NotEqual, @@ -65,6 +67,7 @@ Slice, SquaredError, Squeeze, + TensorItem, ToList, ToTensor, ToTuple, @@ -3499,3 +3502,75 @@ def test_slice_all_keys_given_all_three_parts(): tolerances=1e-6, ignore_transform={"output", "step", "start", "stop"}, ) + + +def test_tensor_item_with_slice_1(): + model = Model() + + item_model = TensorItem() + slice_model = Slice(start=0, stop=1, step=None) + + model += slice_model + model += item_model(input="input", index=slice_model.output, output=IOKey("output")) + + input = {"input": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]} + + out_grad = {"output": [[5.0, 6.0]]} + + ref_out = {"output": [[1.0, 2.0]]} + + ref_grad = {"input": [[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]]} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {}, + "trainable_keys": {"input"}, + "inference": False, + "jit": False, + }, + data={}, + params=input, + output_gradients=out_grad, + reference_outputs=ref_out, + reference_gradients=ref_grad, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"step", "start", "stop"}, + ) + + +def test_tensor_item_with_slice_2(): + model = Model() + + item_model = TensorItem() + slice_model = Slice(start=0, stop=2, step=None) + + model += slice_model + model += item_model(input="input", index=slice_model.output, output=IOKey("output")) + + input = {"input": [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]} + + out_grad = {"output": [[[3.0, 0.0]], [[2.0, 1.0]]]} + + ref_out = {"output": [[[1.0, 2.0]], [[3.0, 4.0]]]} + + ref_grad = {"input": [[[3.0, 0.0]], [[2.0, 1.0]], [[0.0, 0.0]]]} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {}, + "trainable_keys": {"input"}, + "inference": False, + "jit": False, + }, + data={}, + params=input, + output_gradients=out_grad, + reference_outputs=ref_out, + reference_gradients=ref_grad, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"step", "start", "stop"}, + ) diff --git a/tests/scripts/test_constr_counter.py b/tests/scripts/test_constr_counter.py index fdf38fa6..132f218d 100644 --- a/tests/scripts/test_constr_counter.py +++ b/tests/scripts/test_constr_counter.py @@ -18,9 +18,11 @@ from mithril.framework import Scalar, Tensor from mithril.framework.common import ( NOT_GIVEN, + TBD, BaseKey, ConnectionType, GenericTensorType, + IOKey, MyTensor, ShapeRepr, Uniadic, @@ -35,7 +37,8 @@ Model, PrimitiveModel, Relu, - TensorSlice, + Slice, + TensorItem, Transpose, ) @@ -819,23 +822,24 @@ def test_shape_constraint_counter_12(): def test_shape_constraint_counter_13(): model = Model() - - model_1 = TensorSlice(start=0, stop=2, step=None) + slice_model = Slice(start=0, stop=2, step=None) + model_1 = TensorItem(index=TBD) model_2 = Add() model_3 = Add() model_4 = Add() - - model += model_1 + model += slice_model + model += model_1(input="", index=slice_model.output) model += model_2 model += model_3 model += model_4 ref_dict = make_reference_dict( { - model_1.input: [1, 1], - model_1.start: [1], - model_1.stop: [1], - model_1.step: [1], - model_2.left: [1, 1, 1, 2], + slice_model.start: [], + slice_model.stop: [], + slice_model.step: [], + model_1.input: [1, 2], + model_1.index: [2], + model_2.left: [1, 1, 2, 2], model_2.right: [1, 2], model_3.left: [1, 1, 2, 2], model_3.right: [1, 2], @@ -849,11 +853,12 @@ def test_shape_constraint_counter_13(): model_2.set_shapes({"right": [1]}) ref_dict = make_reference_dict( { - model_1.input: [1, 1], - model_1.start: [1], - model_1.stop: [1], - model_1.step: [1], - model_2.left: [1, 1, 1, 3], + slice_model.start: [], + slice_model.stop: [], + slice_model.step: [], + model_1.input: [1, 2], + model_1.index: [2], + model_2.left: [1, 1, 2, 3], model_2.right: [1, 3], model_3.left: [1, 1, 2, 3], model_3.right: [1, 2], @@ -944,34 +949,55 @@ def test_shape_constraint_counter_14(): def test_shape_constraint_counter_15(): model = Model() - model_1 = TensorSlice(start=1, stop=None, step=None) - model_2 = TensorSlice(start=1, stop=None, step=None) - model_3 = TensorSlice(start=1, stop=None, step=None) - model_4 = TensorSlice(start=1, stop=None, step=None) + slice_1 = Slice(start=TBD, stop=TBD, step=TBD) + slice_2 = Slice(start=TBD, stop=TBD, step=TBD) + slice_3 = Slice(start=TBD, stop=TBD, step=TBD) + slice_4 = Slice(start=TBD, stop=TBD, step=TBD) - model += model_1 - model += model_2 - model += model_3 - model += model_4 + item_model_1 = TensorItem() + item_model_2 = TensorItem() + item_model_3 = TensorItem() + item_model_4 = TensorItem() + + model_1 = Model() + model_1 += slice_1(start="start", stop="stop", step="step") + model_1 += item_model_1(input="input", index=slice_1.output, output=IOKey("output")) + + model_2 = Model() + model_2 += slice_2(start="start", stop="stop", step="step") + model_2 += item_model_2(input="input", index=slice_2.output, output=IOKey("output")) + + model_3 = Model() + model_3 += slice_3(start="start", stop="stop", step="step") + model_3 += item_model_3(input="input", index=slice_3.output, output=IOKey("output")) + + model_4 = Model() + model_4 += slice_4(start="start", stop="stop", step="step") + model_4 += item_model_4(input="input", index=slice_4.output, output=IOKey("output")) + + model += model_1(start=1, stop=None, step=None) + model += model_2(start=1, stop=None, step=None) + model += model_3(start=1, stop=None, step=None) + model += model_4(start=1, stop=None, step=None) ref_dict = make_reference_dict( { - model_1.input: [1, 2], - model_1.start: [2], - model_1.stop: [2], - model_1.step: [2], - model_2.input: [1, 1, 2, 3], - model_2.start: [3], - model_2.stop: [3], - model_2.step: [3], - model_3.input: [1, 1, 3, 3], - model_3.start: [3], - model_3.stop: [3], - model_3.step: [3], - model_4.input: [1, 1, 2, 3], - model_4.start: [2], - model_4.stop: [2], - model_4.step: [2], - model_4.output: [1, 2], + model_1.input: [1, 2], # type: ignore + model_1.start: [], # type: ignore + model_1.stop: [], # type: ignore + model_1.step: [], # type: ignore + model_2.input: [1, 1, 2, 2], # type: ignore + model_2.start: [], # type: ignore + model_2.stop: [], # type: ignore + model_2.step: [], # type: ignore + model_3.input: [1, 1, 2, 2], # type: ignore + model_3.start: [], # type: ignore + model_3.stop: [], # type: ignore + model_3.step: [], # type: ignore + model_4.input: [1, 1, 2, 2], # type: ignore + model_4.start: [], # type: ignore + model_4.stop: [], # type: ignore + model_4.step: [], # type: ignore + model_4.output: [1, 2], # type: ignore } ) assert_constr_counts(ref_dict) @@ -979,23 +1005,23 @@ def test_shape_constraint_counter_15(): model_1.set_shapes({"input": [9]}) ref_dict = make_reference_dict( { - model_1.input: [1], - model_1.start: [], - model_1.stop: [], - model_1.step: [], - model_2.input: [1, 1], - model_2.start: [], - model_2.stop: [], - model_2.step: [], - model_3.input: [1, 1], - model_3.start: [], - model_3.stop: [], - model_3.step: [], - model_4.input: [1, 1], - model_4.start: [], - model_4.stop: [], - model_4.step: [], - model_4.output: [1], + model_1.input: [1], # type: ignore + model_1.start: [], # type: ignore + model_1.stop: [], # type: ignore + model_1.step: [], # type: ignore + model_2.input: [1, 1], # type: ignore + model_2.start: [], # type: ignore + model_2.stop: [], # type: ignore + model_2.step: [], # type: ignore + model_3.input: [1, 1], # type: ignore + model_3.start: [], # type: ignore + model_3.stop: [], # type: ignore + model_3.step: [], # type: ignore + model_4.input: [1, 1], # type: ignore + model_4.start: [], # type: ignore + model_4.stop: [], # type: ignore + model_4.step: [], # type: ignore + model_4.output: [1], # type: ignore } ) assert_constr_counts(ref_dict) diff --git a/tests/scripts/test_primitive_directed.py b/tests/scripts/test_primitive_directed.py index a1e43563..42e069dd 100644 --- a/tests/scripts/test_primitive_directed.py +++ b/tests/scripts/test_primitive_directed.py @@ -30,7 +30,7 @@ def assert_forward( formula_key: str, - expected_result: np.ndarray | int | float | tuple | list, + expected_result: np.ndarray | int | float | tuple | list | slice, args: Any, kwargs: dict[str, Any], backends: list[Backend] = backends, @@ -45,13 +45,16 @@ def assert_forward( } primitive_fn = backend.primitive_function_dict[formula_key] result = primitive_fn(*_args, **_kwargs) - np.testing.assert_allclose( - result, - expected_result, - rtol=1e-14, - atol=1e-14, - err_msg=f"Primitive: {formula_key} failed ", - ) + if not isinstance(expected_result, np.ndarray | tuple | list): + assert result == expected_result + else: + np.testing.assert_allclose( + result, + expected_result, + rtol=1e-14, + atol=1e-14, + err_msg=f"Primitive: {formula_key} failed ", + ) def manul_vjp( @@ -1329,48 +1332,6 @@ def test_transpose_axis_4(): ) -def test_tensor_slice_1(): - start = 0 - stop = 1 - step = None - input = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - result = np.array([[1.0, 2.0]]) - - output_grad = np.array([[5.0, 6.0]]) - input_grad = np.array([[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]]) - - assert_forward("tensor_slice", result, (input, start, stop, step), {}) - assert_backward( - "tensor_slice", - (input_grad,), - output_grad, - [0], - {"input": input, "start": start, "stop": stop, "step": step}, - {}, - ) - - -def test_tensor_slice_2(): - start = 0 - stop = 2 - step = None - input = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - result = np.array([[1.0, 2.0], [3.0, 4.0]]) - - output_grad = np.array([[3.0, 0.0], [2.0, 1.0]]) - input_grad = np.array([[3.0, 0.0], [2.0, 1.0], [0.0, 0.0]]) - - assert_forward("tensor_slice", result, (input, start, stop, step), {}) - assert_backward( - "tensor_slice", - (input_grad,), - output_grad, - [0], - {"input": input, "start": start, "stop": stop, "step": step}, - {}, - ) - - def test_tanh_1(): input = np.array([[10.0]]) result = np.array([[0.9999999958776928]]) @@ -4592,3 +4553,21 @@ def test_split_4(): {"input": input, "split_size": split_size, "axis": axis}, {}, ) + + +def test_slice_1(): + start = 2 + stop = 4 + step = 1 + result = slice(2, 4, 1) + + assert_forward("primitive_slice", result, (start, stop, step), {}) + + +def test_slice_2(): + start = 3 + stop = None + step = None + result = slice(3, None, None) + + assert_forward("primitive_slice", result, (start, stop, step), {}) diff --git a/tests/scripts/test_recurrent_models.py b/tests/scripts/test_recurrent_models.py index 2f8b1135..66a4f70b 100644 --- a/tests/scripts/test_recurrent_models.py +++ b/tests/scripts/test_recurrent_models.py @@ -22,7 +22,6 @@ from mithril import TorchBackend from mithril.framework.common import NOT_GIVEN, ConnectionType from mithril.models import ( - TBD, AbsoluteError, Add, Buffer, @@ -36,9 +35,10 @@ OneToMany, ScalarItem, Shape, + Slice, Sum, Tanh, - TensorSlice, + TensorItem, TrainModel, ) from mithril.utils.utils import pack_data_into_time_slots @@ -194,8 +194,12 @@ def __init__( shp_model = Shape() scalar_item = ScalarItem() - slice_model_1 = TensorSlice(start=TBD) - slice_model_2 = TensorSlice(stop=TBD) + slice_1 = Slice(stop=None, step=None) + slice_2 = Slice(start=None, step=None) + + tensor_item_1 = TensorItem() + tensor_item_2 = TensorItem() + mult_model_1 = MatrixMultiply() mult_model_2 = MatrixMultiply() sum_model_1 = Add() @@ -207,12 +211,15 @@ def __init__( self += shp_model(input="input") self += scalar_item(input=shp_model.output, index=0) - self += slice_model_1( - input="prev_hidden", start=scalar_item.output, output=IOKey("hidden_compl") + self += slice_1(start=scalar_item.output) + self += tensor_item_1( + input="prev_hidden", index=slice_1.output, output=IOKey("hidden_compl") ) - self += slice_model_2(input="prev_hidden", stop=scalar_item.output) + + self += slice_2(stop=scalar_item.output) + self += tensor_item_2(input="prev_hidden", index=slice_2.output) self += mult_model_1(left="input", right="w_ih") - self += mult_model_2(left=slice_model_2.output, right="w_hh") + self += mult_model_2(left=tensor_item_2.output, right="w_hh") self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2(left=sum_model_1.output, right="bias_hh") self += sum_model_3( @@ -310,8 +317,13 @@ def __init__( shp_model = Shape() scalar_item = ScalarItem() - slice_model_1 = TensorSlice(start=TBD) - slice_model_2 = TensorSlice(stop=TBD) + + slice_1 = Slice(stop=None, step=None) + slice_2 = Slice(start=None, step=None) + + tensor_item_1 = TensorItem() + tensor_item_2 = TensorItem() + mult_model_1 = MatrixMultiply() mult_model_2 = MatrixMultiply() sum_model_1 = Add() @@ -321,12 +333,14 @@ def __init__( self += shp_model(input="input") self += scalar_item(input=shp_model.output, index=0) - self += slice_model_1( - input="prev_hidden", start=scalar_item.output, output=IOKey("hidden_compl") + self += slice_1(start=scalar_item.output) + self += tensor_item_1( + input="prev_hidden", index=slice_1.output, output=IOKey("hidden_compl") ) - self += slice_model_2(input="prev_hidden", stop=scalar_item.output) + self += slice_2(stop=scalar_item.output) + self += tensor_item_2(input="prev_hidden", index=slice_2.output) self += mult_model_1(left="input", right="w_ih") - self += mult_model_2(left=slice_model_2.output, right="w_hh") + self += mult_model_2(left=tensor_item_2.output, right="w_hh") self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2(left=sum_model_1.output, right="bias_hh") self += sum_model_3(