From 5f9193c0c8bd9daf306fbd8c4300c1758539a511 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 14 Nov 2024 10:16:50 +0300 Subject: [PATCH 1/7] feat: Now People will be assigned based on applied labels, github actions bot will request a change if PR title does not follow Conventional format --- .github/workflows/check-pr-title.yml | 8 ++++++- .github/workflows/pr-label-assign.yml | 31 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/pr-label-assign.yml diff --git a/.github/workflows/check-pr-title.yml b/.github/workflows/check-pr-title.yml index e10a3584..c04ce31d 100644 --- a/.github/workflows/check-pr-title.yml +++ b/.github/workflows/check-pr-title.yml @@ -7,14 +7,20 @@ on: jobs: review_pr_title: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} steps: + - name: Check out the repository + uses: actions/checkout@v4 - name: Check PR title format id: check_title run: | TITLE="${{ github.event.pull_request.title }}" if [[ ! "$TITLE" =~ ^(feat|fix|docs|chore|perf|test|refactor): ]]; then echo "::error::PR title does not follow the convention 'type: description'. Please review the title." + gh pr review ${{ github.event.pull_request.number }} -r -b "This PR does not satisfies PR title format. Title format should be same as one of the conventional commits" exit 1 else echo "PR title meets the required format." - fi \ No newline at end of file + fi + \ No newline at end of file diff --git a/.github/workflows/pr-label-assign.yml b/.github/workflows/pr-label-assign.yml new file mode 100644 index 00000000..909d0894 --- /dev/null +++ b/.github/workflows/pr-label-assign.yml @@ -0,0 +1,31 @@ +# This wrokflow is triggered when a PR is labeled. +# It assigns reviewers based on the label applied to the PR. + +name: PR Label Assigner + +on: + pull_request_target: + types: + - labeled + +jobs: + pr-labeler: + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + steps: + + - name: Check out the repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install PyYAML + run: pip install pyyaml + + - name: Assign Reviewers # Assign responsible people based on labels + run: | + python .github/scripts/assign_reviewers.py "${{ github.event.pull_request.number }}" "${{ github.event.label.name }}" \ No newline at end of file From 8309838d5e5258daba5fd5ba585c737362283b82 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 14 Nov 2024 12:37:43 +0300 Subject: [PATCH 2/7] feat: Now actions bot requests a change in the case of tests failed --- .github/workflows/ci-test.yaml | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-test.yaml b/.github/workflows/ci-test.yaml index efdc296c..fa0e864f 100644 --- a/.github/workflows/ci-test.yaml +++ b/.github/workflows/ci-test.yaml @@ -41,4 +41,19 @@ jobs: - name: Upload results to Codecov uses: codecov/codecov-action@v4 with: - token: ${{ secrets.CODECOV_TOKEN }} + token: ${{ secrets.CODECOV_TOKEN }} + + on_failure: + needs: build + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + if: ${{ failure() }} + steps: + - name: Check out the repository + uses: actions/checkout@v4 + - name: review_pr + id: review-pr + run: | + gh pr review ${{ github.event.pull_request.number }} -r -b "Tests are failed. Please review the PR." + exit 1 \ No newline at end of file From 5916d8f74246f27b0d5975b76f0ef77e97d9d27e Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 21 Nov 2024 15:09:05 +0300 Subject: [PATCH 3/7] fix: bug fix attempt 1 --- .github/workflows/check-pr-title.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-pr-title.yml b/.github/workflows/check-pr-title.yml index c04ce31d..0dae76cd 100644 --- a/.github/workflows/check-pr-title.yml +++ b/.github/workflows/check-pr-title.yml @@ -15,7 +15,7 @@ jobs: - name: Check PR title format id: check_title run: | - TITLE="${{ github.event.pull_request.title }}" + TITLE='${{ github.event.pull_request.title }}' if [[ ! "$TITLE" =~ ^(feat|fix|docs|chore|perf|test|refactor): ]]; then echo "::error::PR title does not follow the convention 'type: description'. Please review the title." gh pr review ${{ github.event.pull_request.number }} -r -b "This PR does not satisfies PR title format. Title format should be same as one of the conventional commits" From 8e4741439dfe8cbe46719102f8955d34575a2fe7 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Fri, 20 Dec 2024 16:00:02 +0300 Subject: [PATCH 4/7] slice model is added to mithril --- .../with_autograd/common_primitives.py | 5 + .../backends/with_autograd/jax_backend/ops.py | 2 + .../backends/with_autograd/mlx_backend/ops.py | 2 + .../with_autograd/torch_backend/ops.py | 2 + .../with_manualgrad/common_primitives.py | 7 + .../with_manualgrad/numpy_backend/ops.py | 2 + mithril/framework/constraints.py | 35 +++ .../framework/logical/essential_primitives.py | 39 ++++ mithril/framework/physical/model.py | 5 +- tests/scripts/test_all_models.py | 217 +++++++++++++++++- tests/scripts/test_constraints.py | 137 ++++++++++- 11 files changed, 448 insertions(+), 5 deletions(-) diff --git a/mithril/backends/with_autograd/common_primitives.py b/mithril/backends/with_autograd/common_primitives.py index cd6892a6..86b2fd18 100644 --- a/mithril/backends/with_autograd/common_primitives.py +++ b/mithril/backends/with_autograd/common_primitives.py @@ -42,6 +42,7 @@ "swapaxes", "square", "tensor_slice", + "primitive_slice", "buffer", "permute_tensor", "reshape", @@ -183,6 +184,10 @@ def tensor_item(input: DataType, index: int | slice | tuple[int | slice, ...]): return input[index] +def primitive_slice(start: int | None, stop: int | None, step: int | None): + return slice(start, stop, step) + + def length(input: DataType) -> int: return len(input) diff --git a/mithril/backends/with_autograd/jax_backend/ops.py b/mithril/backends/with_autograd/jax_backend/ops.py index 321a99c3..50e5ec78 100644 --- a/mithril/backends/with_autograd/jax_backend/ops.py +++ b/mithril/backends/with_autograd/jax_backend/ops.py @@ -51,6 +51,7 @@ permute_tensor, power, primitive_embedding, + primitive_slice, reshape, scalar_item, sequence_slice, @@ -197,6 +198,7 @@ "item", "scalar_item", "tensor_item", + "primitive_slice", "sequence_slice", "union", "length", diff --git a/mithril/backends/with_autograd/mlx_backend/ops.py b/mithril/backends/with_autograd/mlx_backend/ops.py index a40715cc..f6954982 100644 --- a/mithril/backends/with_autograd/mlx_backend/ops.py +++ b/mithril/backends/with_autograd/mlx_backend/ops.py @@ -50,6 +50,7 @@ permute_tensor, power, primitive_embedding, + primitive_slice, reshape, scalar_item, sequence_slice, @@ -170,6 +171,7 @@ "item", "scalar_item", "tensor_item", + "primitive_slice", "sequence_slice", "union", "length", diff --git a/mithril/backends/with_autograd/torch_backend/ops.py b/mithril/backends/with_autograd/torch_backend/ops.py index f9d454cb..17501e3c 100644 --- a/mithril/backends/with_autograd/torch_backend/ops.py +++ b/mithril/backends/with_autograd/torch_backend/ops.py @@ -51,6 +51,7 @@ permute_tensor, power, primitive_embedding, + primitive_slice, reshape, scalar_item, sequence_slice, @@ -187,6 +188,7 @@ "item", "scalar_item", "tensor_item", + "primitive_slice", "sequence_slice", "union", "length", diff --git a/mithril/backends/with_manualgrad/common_primitives.py b/mithril/backends/with_manualgrad/common_primitives.py index 4dab0aed..5910ba7d 100644 --- a/mithril/backends/with_manualgrad/common_primitives.py +++ b/mithril/backends/with_manualgrad/common_primitives.py @@ -50,6 +50,7 @@ "item", "scalar_item", "tensor_item", + "primitive_slice", "swapaxes", "sequence_slice", "union", @@ -295,6 +296,12 @@ def tensor_item( return input[index] +def primitive_slice( + start: int | None, stop: int | None, step: int | None, cache: CacheType = None +): + return slice(start, stop, step) + + def swapaxes( input: DataType, axis1: int, axis2: int, *, cache: CacheType = None ) -> DataType: diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops.py b/mithril/backends/with_manualgrad/numpy_backend/ops.py index 5b3e5566..2f0f2f35 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops.py @@ -50,6 +50,7 @@ padding_converter_2d, permute_tensor, power, + primitive_slice, reshape, scalar_item, sequence_slice, @@ -187,6 +188,7 @@ "item", "scalar_item", "tensor_item", + "primitive_slice", "swapaxes", "sequence_slice", "union", diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index d0db0e78..84cb2f61 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -60,6 +60,7 @@ "floor_divide_type_constraint", "scalar_slice_type_constraint", "scalar_item_type_constraint", + "slice_constraints", "bcast", "bcast_matrix_mult", "sliding_window_1d_constraints", @@ -549,6 +550,40 @@ def scalar_item_type_constraint(output: Scalar, input: Scalar, index: Scalar): return status, updates +def slice_constraints(output: Scalar, start: Scalar, stop: Scalar, step: Scalar): + updates = Updates() + output_value = output.value + start_value = start.value + stop_value = stop.value + step_value = step.value + status = False + + assert isinstance(start_value, ToBeDetermined | int | None) + assert isinstance(stop_value, ToBeDetermined | int | None) + assert isinstance(step_value, ToBeDetermined | int | None) + assert isinstance(output_value, ToBeDetermined | slice) + + if ( + not isinstance(start_value, ToBeDetermined) + and not isinstance(step_value, ToBeDetermined) + and not isinstance(stop_value, ToBeDetermined) + ): + updates |= output.set_value(slice(start_value, stop_value, step_value)) + status = True + + elif not isinstance(output_value, ToBeDetermined): + start_val = output_value.start + stop_val = output_value.stop + step_val = output_value.step + + updates |= start.set_value(start_val) + updates |= stop.set_value(stop_val) + updates |= step.set_value(step_val) + status = True + + return status, updates + + def tensor_to_list_type_constraint(output: Scalar, input: Tensor): status = not is_union(output._type) updates = Updates() diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index 7afda959..e4a07e1b 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -44,6 +44,7 @@ scalar_slice_type_constraint, shape_constraints, size_constraints, + slice_constraints, split_constraints, tensor_item_constraints, tensor_slice_constraints, @@ -106,6 +107,7 @@ "Transpose", "Sqrt", "Split", + "Slice", ] ConstantType = float | int | Constant @@ -1493,3 +1495,40 @@ def __call__( # type: ignore[override] return super().__call__( input=input, split_size=split_size, axis=axis, output=output ) + + +class Slice(PrimitiveModel): + start: Connection + stop: Connection + step: Connection + output: Connection + + def __init__( + self, + start: int | None | ToBeDetermined = 0, + stop: int | None | ToBeDetermined = None, + step: int | None | ToBeDetermined = None, + name: str | None = None, + ): + super().__init__( + formula_key="primitive_slice", + name=name, + output=IOKey(type=slice), + start=IOKey(type=int | None, value=start), + stop=IOKey(type=int | None, value=stop), + step=IOKey(type=int | None, value=step), + ) + self.factory_inputs = {"start": start, "stop": stop, "step": step} + + self._set_constraint( + fn=slice_constraints, keys=["output", "start", "stop", "step"] + ) + + def __call__( # type: ignore[override] + self, + start: ConnectionType = NOT_GIVEN, + stop: ConnectionType = NOT_GIVEN, + step: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(start=start, stop=stop, step=step, output=output) diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index fa2efc90..1a0b5d3d 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -103,12 +103,15 @@ def __init__( # NOTE: Do not set default value if it is given in constant_keys. value = (value, NOT_GIVEN)[key in constant_keys] default_val = model.conns.get_data(key).value - if value is NOT_GIVEN and default_val is TBD: + if (value is NOT_GIVEN and default_val is TBD) or ( + key in model.output_keys + ): # Non-valued connections are only named with their key names. model_keys[key] = key else: val = default_val if default_val is not TBD else value model_keys[key] = IOKey(key, val) # type: ignore + model = Model() + model(**model_keys) self.backend: Backend[DataType] = backend diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index ccdf4e98..6ff9ca60 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -61,6 +61,7 @@ Shape, SiLU, Size, + Slice, SquaredError, Squeeze, ToList, @@ -165,7 +166,9 @@ def compile_and_compare( # We may not include any reference value for some keys for a certain test. # So we don't assert set(outputs.keys()) == set(reference_outputs) since # outputs can have some keys which reference_outputs does not include. - if out is not None: + if not isinstance(out, backend.get_backend_array_type()): + assert v == out + elif out is not None: if tolerance is not None and relative_tolerance is not None: assert ( all(backend.flatten(backend.abs(v - out) < tolerance)) @@ -1559,7 +1562,7 @@ def test_size_1(): reference_outputs = {"output": 4} compile_and_compare( model=model, - compile_kwargs={"constant_keys": statics, "inference": True}, + compile_kwargs={"constant_keys": statics, "inference": True, "jit": False}, data={}, params={}, output_gradients={}, @@ -1577,7 +1580,7 @@ def test_size_2(): reference_outputs = {"output": 24} compile_and_compare( model=model, - compile_kwargs={"constant_keys": statics, "inference": True}, + compile_kwargs={"constant_keys": statics, "inference": True, "jit": False}, data={}, params={}, output_gradients={}, @@ -1887,6 +1890,7 @@ def test_index_1(): compile_kwargs={ "data_keys": {"input"}, "inference": True, + "jit": False, }, data=data, params={}, @@ -1912,6 +1916,7 @@ def test_index_2(): compile_kwargs={ "data_keys": {"input"}, "inference": True, + "jit": False, }, data=data, params={}, @@ -3257,3 +3262,209 @@ def test_groupnorm_4(): assert_shapes=False, tolerances=1e-6, ) + + +def test_slice_all_values_given_in_init(): + """ + Given in __init__: 'start', 'stop', 'step' + Given in data: ... + given as compile constant: ... + """ + start = 3 + stop = 10 + step = 7 + model = Slice(start, stop, step) + pm = mithril.compile(model, JaxBackend(), inference=True, jit=False) + pm.evaluate() + reference_outputs = {"output": slice(start, stop, step)} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {}, + "trainable_keys": {}, + "inference": True, + "jit": False, + }, + data={}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"output"}, + ) + + +def test_slice_given_in_compile_data(): + """ + Given in __init__: 'start', 'stop' + Given in data: 'step' + given as compile constant: ... + """ + start = 1 + stop = 12 + model = Slice(start, stop, step=TBD) + reference_outputs = {"output": slice(1, 12, 2)} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {}, + "trainable_keys": {}, + "inference": True, + "jit": False, + }, + data={"step": 2}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"output", "step"}, + ) + + +def test_slice_given_in_compile_constant(): + """ + Given in __init__: 'start', 'stop' + Given in data: ... + given as compile constant: 'step' + """ + start = 1 + stop = 12 + model = Slice(start, stop, step=TBD) + reference_outputs = {"output": slice(1, 12, 2)} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {"step": 2}, + "trainable_keys": {}, + "inference": True, + "jit": False, + }, + data={}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"output", "step"}, + ) + + +def test_slice_all_keys_given_as_constants(): + """ + Given in __init__: ... + Given in data: ... + given as compile constant: 'start', 'stop', 'step' + """ + model = Slice(start=TBD, stop=TBD, step=TBD) + reference_outputs = {"output": slice(1, 12, 2)} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {"start": 1, "stop": 12, "step": 2}, + "trainable_keys": {}, + "inference": True, + "jit": False, + }, + data={}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"output", "step", "start", "stop"}, + ) + + +def test_slice_all_keys_given_in_data(): + """ + Given in __init__: ... + Given in data: 'start', 'stop', 'step' + given as compile constant: ... + """ + model = Slice(start=TBD, stop=TBD, step=TBD) + reference_outputs = {"output": slice(1, 12, 2)} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {}, + "trainable_keys": {}, + "inference": True, + "jit": False, + }, + data={"start": 1, "stop": 12, "step": 2}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"output", "step", "start", "stop"}, + ) + + +def test_slice_all_keys_given_in_constant_and_data(): + """ + Given in __init__: ... + Given in data: 'start, stop' + given as compile constant: 'step' + """ + model = Slice(start=TBD, stop=TBD, step=TBD) + reference_outputs = {"output": slice(1, 12, 2)} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {"step": 2}, + "trainable_keys": {}, + "inference": True, + "jit": False, + }, + data={"start": 1, "stop": 12}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"output", "step", "start", "stop"}, + ) + + +def test_slice_all_keys_given_all_three_parts(): + """ + Given in __init__: 'start' + Given in data: 'stop' + given as compile constant: 'step' + """ + + model = Slice(start=1, stop=TBD, step=TBD) + reference_outputs = {"output": slice(1, 12, 2)} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {"step": 2}, + "trainable_keys": {}, + "inference": True, + "jit": False, + }, + data={"stop": 12}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"output", "step", "start", "stop"}, + ) diff --git a/tests/scripts/test_constraints.py b/tests/scripts/test_constraints.py index e5ff1758..d5589af7 100644 --- a/tests/scripts/test_constraints.py +++ b/tests/scripts/test_constraints.py @@ -60,6 +60,7 @@ scalar_slice_type_constraint, shape_constraints, size_constraints, + slice_constraints, sliding_window_2d_constraints, split_constraints, squeeze_constraints, @@ -250,7 +251,9 @@ def assert_value_results( data: dict[str, Tensor | Scalar], ref_results: dict[str, Any] ) -> None: for key, value in ref_results.items(): - if isinstance(value, int | float | bool | tuple | list | str | ToBeDetermined): + if isinstance( + value, int | float | bool | tuple | list | str | ToBeDetermined | slice + ): assert data[key].value == value else: # If value is a tensor of any supported backend. @@ -7800,3 +7803,135 @@ def test_bcast_with_var_possibles_4(): assert_constraint_results( shapes, assignments, final_shapes, final_assignments, bcast, False, {"left"} ) + + +def test_slice_given_input(): + shapes: dict[str, list[int | str | tuple]] = {} + final_shapes: dict[str, list[int | str | tuple]] = { + "output": [], + "start": [], + "stop": [], + "step": [], + } + scalar_info = { + "output": Scalar(possible_types=slice, value=TBD), + "start": Scalar(possible_types=int | None, value=1), + "stop": Scalar(possible_types=int | None, value=3), + "step": Scalar(possible_types=int | None, value=5), + } + final_values = { + "output": slice(1, 3, 5), + "start": 1, + "stop": 3, + "step": 5, + } + assert_constraint_results( + shapes, + {}, + final_shapes, + {}, + slice_constraints, + True, + {"output"}, + scalar_info, + final_values, + ) + + +def test_slice_given_missing_input(): + shapes: dict[str, list[int | str | tuple]] = {} + final_shapes: dict[str, list[int | str | tuple]] = { + "output": [], + "start": [], + "stop": [], + "step": [], + } + scalar_info = { + "output": Scalar(possible_types=slice, value=TBD), + "start": Scalar(possible_types=int | None, value=1), + "stop": Scalar(possible_types=int | None, value=3), + "step": Scalar(possible_types=int | None, value=TBD), + } + final_values = { + "output": TBD, + "start": 1, + "stop": 3, + "step": TBD, + } + assert_constraint_results( + shapes, + {}, + final_shapes, + {}, + slice_constraints, + False, + set(), + scalar_info, + final_values, + ) + + +def test_slice_given_output(): + shapes: dict[str, list[int | str | tuple]] = {} + final_shapes: dict[str, list[int | str | tuple]] = { + "output": [], + "start": [], + "stop": [], + "step": [], + } + scalar_info = { + "output": Scalar(possible_types=slice, value=slice(1, 3, 5)), + "start": Scalar(possible_types=int | None, value=1), + "stop": Scalar(possible_types=int | None, value=3), + "step": Scalar(possible_types=int | None, value=TBD), + } + final_values = { + "output": slice(1, 3, 5), + "start": 1, + "stop": 3, + "step": 5, + } + assert_constraint_results( + shapes, + {}, + final_shapes, + {}, + slice_constraints, + True, + {"step"}, + scalar_info, + final_values, + ) + + +def test_slice_given_output_missing_all_inputs(): + shapes: dict[str, list[int | str | tuple]] = {} + final_shapes: dict[str, list[int | str | tuple]] = { + "output": [], + "start": [], + "stop": [], + "step": [], + } + scalar_info = { + "output": Scalar(possible_types=slice, value=slice(1, 3, 5)), + "start": Scalar(possible_types=int | None, value=TBD), + "stop": Scalar(possible_types=int | None, value=TBD), + "step": Scalar(possible_types=int | None, value=TBD), + } + final_values = { + "output": slice(1, 3, 5), + "start": 1, + "stop": 3, + "step": 5, + } + assert_constraint_results( + shapes, + {}, + final_shapes, + {}, + slice_constraints, + True, + {"start", "stop", "step"}, + scalar_info, + final_values, + ) From 85ca767049a1d742327855452d147b42ca12870d Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 23 Dec 2024 13:12:48 +0300 Subject: [PATCH 5/7] TensorSlice model is removed --- .../with_autograd/common_primitives.py | 7 - .../backends/with_autograd/jax_backend/ops.py | 2 - .../backends/with_autograd/mlx_backend/ops.py | 2 - .../with_autograd/torch_backend/ops.py | 2 - .../with_manualgrad/common_primitives.py | 11 -- .../with_manualgrad/numpy_backend/ops.py | 2 - .../with_manualgrad/numpy_backend/ops_grad.py | 28 ---- mithril/framework/constraints.py | 31 ---- .../framework/logical/essential_primitives.py | 138 ++++++----------- mithril/framework/logical/model.py | 2 - mithril/models/models.py | 70 +++++---- tests/json_files/models_directed_test.json | 78 ---------- .../json_files/primitives_directed_test.json | 61 -------- .../randomized_model_tests_all_backends.json | 16 -- tests/scripts/test_constr_counter.py | 139 +++++++++++------- tests/scripts/test_primitive_directed.py | 42 ------ tests/scripts/test_recurrent_models.py | 41 ++++-- 17 files changed, 199 insertions(+), 473 deletions(-) diff --git a/mithril/backends/with_autograd/common_primitives.py b/mithril/backends/with_autograd/common_primitives.py index 86b2fd18..b9a63c42 100644 --- a/mithril/backends/with_autograd/common_primitives.py +++ b/mithril/backends/with_autograd/common_primitives.py @@ -41,7 +41,6 @@ "transpose", "swapaxes", "square", - "tensor_slice", "primitive_slice", "buffer", "permute_tensor", @@ -158,12 +157,6 @@ def square(input: DataType): return input * input -def tensor_slice( - input: DataType, start: int | None, stop: int | None, step: int | None -): - return input[start:stop:step] - - def buffer(input: DataType): return input diff --git a/mithril/backends/with_autograd/jax_backend/ops.py b/mithril/backends/with_autograd/jax_backend/ops.py index 50e5ec78..f11672ab 100644 --- a/mithril/backends/with_autograd/jax_backend/ops.py +++ b/mithril/backends/with_autograd/jax_backend/ops.py @@ -63,7 +63,6 @@ subtract, swapaxes, tensor_item, - tensor_slice, to_list, to_tuple, transpose, @@ -191,7 +190,6 @@ "transpose", "swapaxes", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", diff --git a/mithril/backends/with_autograd/mlx_backend/ops.py b/mithril/backends/with_autograd/mlx_backend/ops.py index f6954982..650192a4 100644 --- a/mithril/backends/with_autograd/mlx_backend/ops.py +++ b/mithril/backends/with_autograd/mlx_backend/ops.py @@ -62,7 +62,6 @@ subtract, swapaxes, tensor_item, - tensor_slice, to_list, to_tuple, transpose, @@ -164,7 +163,6 @@ "transpose", "swapaxes", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", diff --git a/mithril/backends/with_autograd/torch_backend/ops.py b/mithril/backends/with_autograd/torch_backend/ops.py index 17501e3c..5ea63444 100644 --- a/mithril/backends/with_autograd/torch_backend/ops.py +++ b/mithril/backends/with_autograd/torch_backend/ops.py @@ -63,7 +63,6 @@ subtract, swapaxes, tensor_item, - tensor_slice, to_list, to_tuple, tuple_converter, @@ -181,7 +180,6 @@ "transpose", "swapaxes", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", diff --git a/mithril/backends/with_manualgrad/common_primitives.py b/mithril/backends/with_manualgrad/common_primitives.py index 5910ba7d..34946316 100644 --- a/mithril/backends/with_manualgrad/common_primitives.py +++ b/mithril/backends/with_manualgrad/common_primitives.py @@ -43,7 +43,6 @@ "squared_error", "transpose", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", @@ -165,16 +164,6 @@ def square(input: DataType, cache: CacheType = None): return input * input -def tensor_slice( - input: DataType, - start: int | None, - stop: int | None, - step: int | None, - cache: CacheType = None, -): - return input[start:stop:step] - - def buffer(input: DataType, cache: CacheType = None): return input diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops.py b/mithril/backends/with_manualgrad/numpy_backend/ops.py index 2f0f2f35..ee36d315 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops.py @@ -62,7 +62,6 @@ subtract, swapaxes, tensor_item, - tensor_slice, to_list, to_tuple, transpose, @@ -181,7 +180,6 @@ "squared_error", "transpose", "square", - "tensor_slice", "buffer", "permute_tensor", "reshape", diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py index 1c46a5f9..c51c4314 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py @@ -75,7 +75,6 @@ "softplus_grad", "gelu_grad", "stop_gradient_grad", - "tensor_slice_grad", "tensor_item_grad", "permute_tensor_grad", "transpose_grad", @@ -804,33 +803,6 @@ def stop_gradient_grad( return np.zeros_like(output_gradient) -# def tensor_slice_grad(output_gradient: np.ndarray, -# cache: CacheType, -# idx: int, -# *inputs: np.ndarray, -> np.ndarray: -# verify_shapes(inputs, idx) -# input1, input2 = inputs -# if idx == 0: -# grad = np.zeros_like(input1) -# grad[:input2.shape[0], ...] = output_gradient -# return grad -# elif idx == 1: -# return np.zeros_like(input2) - - -def tensor_slice_grad( - output_gradient: np.ndarray, - cache: CacheType, - idx: int, - *inputs: np.ndarray, -) -> np.ndarray: - verify_shapes(inputs, idx, non_differentiables=[1, 2, 3, 4]) - input, start, stop, step = inputs - grad = np.zeros_like(input) - grad[start:stop:step] = output_gradient - return grad - - def tensor_item_grad( output_gradient: np.ndarray, cache: CacheType, diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index 84cb2f61..391db800 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -87,7 +87,6 @@ "scalar_item_constraints", "to_tuple_constraints", "tensor_item_constraints", - "tensor_slice_constraints", "tensor_to_list_type_constraint", "reduce_type_constraint", "type_constraints", @@ -3419,36 +3418,6 @@ def tensor_item_constraint_helper( return input_unis, output_unis, status, current_index -def tensor_slice_constraints( - output: Tensor, input: Tensor, start: Scalar, stop: Scalar, step: Scalar -) -> ConstrainResultType: - assert output._temp_shape is not None, "Output shape of TensorSlice is not set!" - assert input._temp_shape is not None, "Input shape of TensorSlice is not set!" - output_shape: ShapeRepr = output._temp_shape - input_shape: ShapeRepr = input._temp_shape - updated_symbols = Updates() - status = False - - if input_shape.prefix and output_shape.prefix: - in_uni, out_uni = input_shape[0], output_shape[0] - if in_uni.value is not None and out_uni.value is not None: - status = True - else: - if ( - start.value is not TBD - and stop.value is not TBD - and step.value is not TBD - and in_uni.value is not None - ): - slc = slice(start.value, stop.value, step.value) - out_val = len(list(range(in_uni.value))[slc]) - out_uni.set_value(out_val) - updated_symbols.add(out_uni) - status = True - - return status, updated_symbols - - def split_constraints(output: Tensor, input: Tensor, split_size: Scalar, axis: Scalar): status = False split_size_val = split_size.value diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index e4a07e1b..1ad55ca6 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -47,7 +47,6 @@ slice_constraints, split_constraints, tensor_item_constraints, - tensor_slice_constraints, tensor_to_list_constraints, tensor_to_list_type_constraint, to_list_constraints, @@ -100,7 +99,6 @@ "ShiftLeft", "ShiftRight", "TensorItem", - "TensorSlice", "ArgMax", "ArgMin", "Cast", @@ -603,54 +601,6 @@ def __call__( # type: ignore[override] ) -class TensorSlice(PrimitiveModel): - input: Connection - start: Connection - stop: Connection - step: Connection - output: Connection - - def __init__( - self, - name: str | None = None, - start: int | None | ToBeDetermined = None, - stop: int | None | ToBeDetermined = None, - step: int | None | ToBeDetermined = None, - input: TensorValueType | ToBeDetermined = TBD, - ) -> None: - self.factory_args = {"start": start, "stop": stop, "step": step} - super().__init__( - formula_key="tensor_slice", - name=name, - output=IOKey(shape=["a", ("Var1", ...)], type=GenericTensorType), - input=IOKey(shape=["b", ("Var1", ...)], type=GenericTensorType), - start=IOKey(type=int | None, value=start), - stop=IOKey(type=int | None, value=stop), - step=IOKey(type=int | None, value=step), - ) - self.factory_inputs = {"input": input} - - self._set_constraint( - fn=tensor_slice_constraints, - keys=[PrimitiveModel.output_key, "input", "start", "stop", "step"], - ) - self._set_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] - ) - - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - start: ConnectionType = NOT_GIVEN, - stop: ConnectionType = NOT_GIVEN, - step: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__( - input=input, start=start, stop=stop, step=step, output=output - ) - - class Item(PrimitiveModel): input: Connection output: Connection @@ -715,50 +665,6 @@ def __call__( # type: ignore[override] return super().__call__(input=input, index=index, output=output) -class TensorItem(PrimitiveModel): - input: Connection - index: Connection - output: Connection - - def __init__( - self, - name: str | None = None, - index: int | ToBeDetermined = TBD, - input: TensorValueType | ToBeDetermined = TBD, - ) -> None: - super().__init__( - formula_key="tensor_item", - name=name, - output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), - input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), - index=IOKey( - type=int - | slice - | EllipsisType - | None - | tuple[int | slice | EllipsisType | None, ...], - value=index, - ), - ) - self.factory_inputs = {"input": input, "index": index} - - self._set_constraint( - fn=tensor_item_constraints, - keys=[PrimitiveModel.output_key, "input", "index"], - ) - self._set_constraint( - fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] - ) - - def __call__( # type: ignore[override] - self, - input: ConnectionType = NOT_GIVEN, - index: ConnectionType = NOT_GIVEN, - output: ConnectionType = NOT_GIVEN, - ) -> ExtendInfo: - return super().__call__(input=input, index=index, output=output) - - class ToTensor(PrimitiveModel): input: Connection output: Connection @@ -1532,3 +1438,47 @@ def __call__( # type: ignore[override] output: ConnectionType = NOT_GIVEN, ) -> ExtendInfo: return super().__call__(start=start, stop=stop, step=step, output=output) + + +class TensorItem(PrimitiveModel): + input: Connection + index: Connection + output: Connection + + def __init__( + self, + name: str | None = None, + index: int | ToBeDetermined = TBD, + input: TensorValueType | ToBeDetermined = TBD, + ) -> None: + super().__init__( + formula_key="tensor_item", + name=name, + output=IOKey(shape=[("Var2", ...)], type=GenericTensorType), + input=IOKey(shape=[("Var1", ...)], type=GenericTensorType), + index=IOKey( + type=int + | slice + | EllipsisType + | None + | tuple[int | slice | EllipsisType | None, ...], + value=index, + ), + ) + self.factory_inputs = {"input": input, "index": index} + + self._set_constraint( + fn=tensor_item_constraints, + keys=[PrimitiveModel.output_key, "input", "index"], + ) + self._set_constraint( + fn=general_tensor_type_constraint, keys=[PrimitiveModel.output_key, "input"] + ) + + def __call__( # type: ignore[override] + self, + input: ConnectionType = NOT_GIVEN, + index: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(input=input, index=index, output=output) diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index d13fb87e..0811028f 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -84,7 +84,6 @@ Subtract, Sum, TensorItem, - TensorSlice, TensorToList, ToList, ToTensor, @@ -139,7 +138,6 @@ coercion_table: dict[tuple[str, type[Tensor] | type[Scalar]], type[PrimitiveModel]] = { ("item", Tensor): TensorItem, ("item", Scalar): ScalarItem, - ("slice", Tensor): TensorSlice, ("slice", Scalar): PrimitiveSlice, } diff --git a/mithril/models/models.py b/mithril/models/models.py index 0a59b7ec..e3e9a843 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -50,10 +50,11 @@ ScalarItem, Shape, Size, + Slice, Sqrt, Subtract, Sum, - TensorSlice, + TensorItem, Transpose, Variance, ) @@ -1369,7 +1370,10 @@ def __init__( shape = Shape() scalar_item = ScalarItem() - slice_model = TensorSlice(stop=TBD) + slice_1 = Slice(start=TBD) + slice_2 = Slice(stop=TBD) + tensor_item_1 = TensorItem() + tensor_item_2 = TensorItem() mult_model_1 = Linear(use_bias=False) mult_model_2 = Linear(use_bias=False) mult_model_3 = Linear(use_bias=False) @@ -1378,13 +1382,15 @@ def __init__( self += shape(input="input") self += scalar_item(input=shape.output, index=0) - self += TensorSlice(start=TBD)( + self += slice_1(start=scalar_item.output) + self += tensor_item_1( input="prev_hidden", - start=scalar_item.output, + index=slice_1.output, output=IOKey(name="hidden_compl"), ) - self += slice_model(input="prev_hidden", stop=scalar_item.output) - self += mult_model_1(input=slice_model.output, weight="w_hh") + self += slice_2(stop=scalar_item.output) + self += tensor_item_2(input="prev_hidden", index=slice_2.output) + self += mult_model_1(input=tensor_item_2.output, weight="w_hh") self += mult_model_2(input="input", weight="w_ih") self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2(left=sum_model_1.output, right="bias_h") @@ -1501,40 +1507,52 @@ def __init__( cell_body = LSTMCellBody() shape_model = Shape() scalar_item = ScalarItem() - slice_model_1 = TensorSlice(stop=TBD) - slice_model_2 = TensorSlice(stop=TBD) - slice_model_3 = TensorSlice(start=TBD) - slice_model_4 = TensorSlice(stop=TBD) + + slice_1 = Slice(stop=TBD) + slice_2 = Slice(stop=TBD) + slice_3 = Slice(start=TBD) + slice_4 = Slice(stop=TBD) + slice_5 = Slice(start=TBD) + + tensor_item_1 = TensorItem() + tensor_item_2 = TensorItem() + tensor_item_3 = TensorItem() + tensor_item_4 = TensorItem() + tensor_item_5 = TensorItem() self += shape_model(input="input") self += scalar_item(input=shape_model.output, index=0) # Forget gate processes. - self += slice_model_1(input="prev_cell", stop=scalar_item.output) - self += slice_model_2(input="prev_hidden", stop=scalar_item.output) + self += slice_1(stop=scalar_item.output) + self += tensor_item_1(input="prev_cell", index=slice_1.output) + + self += slice_2(stop=scalar_item.output) + self += tensor_item_2(input="prev_hidden", index=slice_2.output) body_kwargs: dict[str, ConnectionType] = { key: key for key in cell_body._input_keys if key[0] != "$" } - body_kwargs["prev_cell"] = slice_model_1.output - body_kwargs["prev_hidden"] = slice_model_2.output + body_kwargs["prev_cell"] = tensor_item_1.output + body_kwargs["prev_hidden"] = tensor_item_2.output self += cell_body(**body_kwargs) - self += slice_model_3( - input=cell_body.output, - start=scalar_item.output, - output=IOKey(name="hidden"), + self += slice_3(start=scalar_item.output) + self += tensor_item_3( + input=cell_body.output, index=slice_3.output, output=IOKey(name="hidden") ) - self += slice_model_4( - input=cell_body.output, stop=scalar_item.output, output=IOKey(name="cell") + self += slice_4(stop=scalar_item.output) + self += tensor_item_4( + input=cell_body.output, index=slice_4.output, output=IOKey(name="cell") ) # Slice complement process. - self += TensorSlice(start=TBD)( + self += slice_5(start=scalar_item.output) + self += tensor_item_5( input="prev_hidden", - start=scalar_item.output, + index=slice_5.output, output=IOKey(name="hidden_compl"), ) # Final output. @@ -1795,7 +1813,8 @@ def __init__( # current time step. shape_model = Shape() item_model = ScalarItem() - slice_model = TensorSlice(stop=TBD) + slice_model = Slice(stop=TBD) + tensor_item = TensorItem() self += shape_model(input=f"target{idx}") self += item_model(input=shape_model.output, index=0) @@ -1811,9 +1830,10 @@ def __init__( # of previous time step as inputs to the current time step. slice_input_1 = getattr(prev_cell, prev_cell.out_key) - self += slice_model(input=slice_input_1, stop=item_model.output) + self += slice_model(stop=item_model.output) + self += tensor_item(input=slice_input_1, index=slice_model.output) - input_kwargs = {"input": slice_model.output} + input_kwargs = {"input": tensor_item.output} output_kwargs = {cell_type.out_key: IOKey(name=f"output{idx}")} self += current_cell( diff --git a/tests/json_files/models_directed_test.json b/tests/json_files/models_directed_test.json index 30f062bc..8aab8739 100644 --- a/tests/json_files/models_directed_test.json +++ b/tests/json_files/models_directed_test.json @@ -4120,84 +4120,6 @@ } }, - "test_tensor_slice_1": { - - "model": { - "name": "TensorSlice", - "args": { - "start": 0, - "stop": 1, - "step": null - } - }, - "inputs": { - "input": [ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0] - ] - }, - - "output_grads": { - "output": [ - [5.0, 6.0] - ] - }, - "results": { - "eval": { - "output": [ - [1.0, 2.0] - ] - }, - "grad": { - "input": [[5.0, 6.0], - [0.0, 0.0], - [0.0, 0.0]] - } - } - }, - - "test_tensor_slice_2": { - - "model": { - "name": "TensorSlice", - "args": { - "start": 0, - "stop": 2, - "step": null - } - }, - "inputs": { - "input": [ - [[1.0, 2.0]], - [[3.0, 4.0]], - [[5.0, 6.0]] - ] - }, - - "output_grads": { - "output": [ - [[3.0, 0.0]], - [[2.0, 1.0]] - ] - }, - "results": { - "eval": { - "output": [ - [[1.0, 2.0]], - [[3.0, 4.0]] - ] - }, - "grad": { - "input": [ - [[3.0, 0.0]], - [[2.0, 1.0]], - [[0.0, 0.0]] - ] - } - } - }, - "test_tanh_1": { "model": { diff --git a/tests/json_files/primitives_directed_test.json b/tests/json_files/primitives_directed_test.json index 43ca76d8..b26fb856 100644 --- a/tests/json_files/primitives_directed_test.json +++ b/tests/json_files/primitives_directed_test.json @@ -1023,67 +1023,6 @@ } } }, - "test_tensor_slice_1": { - - "model": { - "name": "TensorSlice", - "args": { - "start": 0, - "stop": 1, - "step": null - } - }, - "inputs": { - "input": [ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0] - ] - }, - - "output_grad": [[5.0, 6.0]], - - "results": { - "eval": [[1.0, 2.0]], - "grad": { - "input": [[5.0, 6.0], - [0.0, 0.0], - [0.0, 0.0]] - } - } - }, - - "test_tensor_slice_2": { - - "model": { - "name": "TensorSlice", - "args": { - "start": 0, - "stop": 2, - "step": null - } - }, - "inputs": { - "input": [ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0] - ] - }, - - "output_grad": [[3.0, 0.0], [2.0, 1.0]], - - "results": { - "eval": [[1.0, 2.0], [3.0, 4.0]], - "grad": { - "input": [ - [3.0, 0.0], - [2.0, 1.0], - [0.0, 0.0] - ] - } - } - }, "test_tanh_1": { "model": "Tanh", "inputs": { diff --git a/tests/json_files/randomized_model_tests_all_backends.json b/tests/json_files/randomized_model_tests_all_backends.json index 38b1595e..82c0609b 100644 --- a/tests/json_files/randomized_model_tests_all_backends.json +++ b/tests/json_files/randomized_model_tests_all_backends.json @@ -2055,22 +2055,6 @@ }, "iterations": 5 }, - "test_tensor_slice": { - "model": { - "name": "TensorSlice", - "randomized_args": { - "start": [1,3], - "stop": [20, 30], - "step": [1, 4] - } - }, - "input_info": { - "input": { - "shapes": [[35,40], [3,3], [3,3], [3,3], [3,3]] - } - }, - "iterations": 20 - }, "test_where": { "model": { "name": "Where" diff --git a/tests/scripts/test_constr_counter.py b/tests/scripts/test_constr_counter.py index 43421af3..1206352d 100644 --- a/tests/scripts/test_constr_counter.py +++ b/tests/scripts/test_constr_counter.py @@ -18,6 +18,7 @@ from mithril.framework import Scalar, Tensor from mithril.framework.common import ( NOT_GIVEN, + TBD, ConnectionType, GenericTensorType, IOKey, @@ -35,7 +36,8 @@ Model, PrimitiveModel, Relu, - TensorSlice, + Slice, + TensorItem, Transpose, ) @@ -819,23 +821,24 @@ def test_shape_constraint_counter_12(): def test_shape_constraint_counter_13(): model = Model() - - model_1 = TensorSlice(start=0, stop=2, step=None) + slice_model = Slice(start=0, stop=2, step=None) + model_1 = TensorItem(index=TBD) model_2 = Add() model_3 = Add() model_4 = Add() - - model += model_1 + model += slice_model + model += model_1(input="", index=slice_model.output) model += model_2 model += model_3 model += model_4 ref_dict = make_reference_dict( { - model_1.input: [1, 1], - model_1.start: [1], - model_1.stop: [1], - model_1.step: [1], - model_2.left: [1, 1, 1, 2], + slice_model.start: [], + slice_model.stop: [], + slice_model.step: [], + model_1.input: [1, 2], + model_1.index: [2], + model_2.left: [1, 1, 2, 2], model_2.right: [1, 2], model_3.left: [1, 1, 2, 2], model_3.right: [1, 2], @@ -849,11 +852,12 @@ def test_shape_constraint_counter_13(): model_2.set_shapes({"right": [1]}) ref_dict = make_reference_dict( { - model_1.input: [1, 1], - model_1.start: [1], - model_1.stop: [1], - model_1.step: [1], - model_2.left: [1, 1, 1, 3], + slice_model.start: [], + slice_model.stop: [], + slice_model.step: [], + model_1.input: [1, 2], + model_1.index: [2], + model_2.left: [1, 1, 2, 3], model_2.right: [1, 3], model_3.left: [1, 1, 2, 3], model_3.right: [1, 2], @@ -944,34 +948,55 @@ def test_shape_constraint_counter_14(): def test_shape_constraint_counter_15(): model = Model() - model_1 = TensorSlice(start=1, stop=None, step=None) - model_2 = TensorSlice(start=1, stop=None, step=None) - model_3 = TensorSlice(start=1, stop=None, step=None) - model_4 = TensorSlice(start=1, stop=None, step=None) + slice_1 = Slice(start=TBD, stop=TBD, step=TBD) + slice_2 = Slice(start=TBD, stop=TBD, step=TBD) + slice_3 = Slice(start=TBD, stop=TBD, step=TBD) + slice_4 = Slice(start=TBD, stop=TBD, step=TBD) - model += model_1 - model += model_2 - model += model_3 - model += model_4 + item_model_1 = TensorItem() + item_model_2 = TensorItem() + item_model_3 = TensorItem() + item_model_4 = TensorItem() + + model_1 = Model() + model_1 += slice_1(start="start", stop="stop", step="step") + model_1 += item_model_1(input="input", index=slice_1.output, output=IOKey("output")) + + model_2 = Model() + model_2 += slice_2(start="start", stop="stop", step="step") + model_2 += item_model_2(input="input", index=slice_2.output, output=IOKey("output")) + + model_3 = Model() + model_3 += slice_3(start="start", stop="stop", step="step") + model_3 += item_model_3(input="input", index=slice_3.output, output=IOKey("output")) + + model_4 = Model() + model_4 += slice_4(start="start", stop="stop", step="step") + model_4 += item_model_4(input="input", index=slice_4.output, output=IOKey("output")) + + model += model_1(start=1, stop=None, step=None) + model += model_2(start=1, stop=None, step=None) + model += model_3(start=1, stop=None, step=None) + model += model_4(start=1, stop=None, step=None) ref_dict = make_reference_dict( { - model_1.input: [1, 2], - model_1.start: [2], - model_1.stop: [2], - model_1.step: [2], - model_2.input: [1, 1, 2, 3], - model_2.start: [3], - model_2.stop: [3], - model_2.step: [3], - model_3.input: [1, 1, 3, 3], - model_3.start: [3], - model_3.stop: [3], - model_3.step: [3], - model_4.input: [1, 1, 2, 3], - model_4.start: [2], - model_4.stop: [2], - model_4.step: [2], - model_4.output: [1, 2], + model_1.input: [1, 2], # type: ignore + model_1.start: [], # type: ignore + model_1.stop: [], # type: ignore + model_1.step: [], # type: ignore + model_2.input: [1, 1, 2, 2], # type: ignore + model_2.start: [], # type: ignore + model_2.stop: [], # type: ignore + model_2.step: [], # type: ignore + model_3.input: [1, 1, 2, 2], # type: ignore + model_3.start: [], # type: ignore + model_3.stop: [], # type: ignore + model_3.step: [], # type: ignore + model_4.input: [1, 1, 2, 2], # type: ignore + model_4.start: [], # type: ignore + model_4.stop: [], # type: ignore + model_4.step: [], # type: ignore + model_4.output: [1, 2], # type: ignore } ) assert_constr_counts(ref_dict) @@ -979,23 +1004,23 @@ def test_shape_constraint_counter_15(): model_1.set_shapes({"input": [9]}) ref_dict = make_reference_dict( { - model_1.input: [1], - model_1.start: [], - model_1.stop: [], - model_1.step: [], - model_2.input: [1, 1], - model_2.start: [], - model_2.stop: [], - model_2.step: [], - model_3.input: [1, 1], - model_3.start: [], - model_3.stop: [], - model_3.step: [], - model_4.input: [1, 1], - model_4.start: [], - model_4.stop: [], - model_4.step: [], - model_4.output: [1], + model_1.input: [1], # type: ignore + model_1.start: [], # type: ignore + model_1.stop: [], # type: ignore + model_1.step: [], # type: ignore + model_2.input: [1, 1], # type: ignore + model_2.start: [], # type: ignore + model_2.stop: [], # type: ignore + model_2.step: [], # type: ignore + model_3.input: [1, 1], # type: ignore + model_3.start: [], # type: ignore + model_3.stop: [], # type: ignore + model_3.step: [], # type: ignore + model_4.input: [1, 1], # type: ignore + model_4.start: [], # type: ignore + model_4.stop: [], # type: ignore + model_4.step: [], # type: ignore + model_4.output: [1], # type: ignore } ) assert_constr_counts(ref_dict) diff --git a/tests/scripts/test_primitive_directed.py b/tests/scripts/test_primitive_directed.py index a1e43563..cf8ea49b 100644 --- a/tests/scripts/test_primitive_directed.py +++ b/tests/scripts/test_primitive_directed.py @@ -1329,48 +1329,6 @@ def test_transpose_axis_4(): ) -def test_tensor_slice_1(): - start = 0 - stop = 1 - step = None - input = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - result = np.array([[1.0, 2.0]]) - - output_grad = np.array([[5.0, 6.0]]) - input_grad = np.array([[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]]) - - assert_forward("tensor_slice", result, (input, start, stop, step), {}) - assert_backward( - "tensor_slice", - (input_grad,), - output_grad, - [0], - {"input": input, "start": start, "stop": stop, "step": step}, - {}, - ) - - -def test_tensor_slice_2(): - start = 0 - stop = 2 - step = None - input = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) - result = np.array([[1.0, 2.0], [3.0, 4.0]]) - - output_grad = np.array([[3.0, 0.0], [2.0, 1.0]]) - input_grad = np.array([[3.0, 0.0], [2.0, 1.0], [0.0, 0.0]]) - - assert_forward("tensor_slice", result, (input, start, stop, step), {}) - assert_backward( - "tensor_slice", - (input_grad,), - output_grad, - [0], - {"input": input, "start": start, "stop": stop, "step": step}, - {}, - ) - - def test_tanh_1(): input = np.array([[10.0]]) result = np.array([[0.9999999958776928]]) diff --git a/tests/scripts/test_recurrent_models.py b/tests/scripts/test_recurrent_models.py index 2f8b1135..21557293 100644 --- a/tests/scripts/test_recurrent_models.py +++ b/tests/scripts/test_recurrent_models.py @@ -36,9 +36,10 @@ OneToMany, ScalarItem, Shape, + Slice, Sum, Tanh, - TensorSlice, + TensorItem, TrainModel, ) from mithril.utils.utils import pack_data_into_time_slots @@ -194,8 +195,12 @@ def __init__( shp_model = Shape() scalar_item = ScalarItem() - slice_model_1 = TensorSlice(start=TBD) - slice_model_2 = TensorSlice(stop=TBD) + slice_1 = Slice(start=TBD) + slice_2 = Slice(stop=TBD) + + tensor_item_1 = TensorItem() + tensor_item_2 = TensorItem() + mult_model_1 = MatrixMultiply() mult_model_2 = MatrixMultiply() sum_model_1 = Add() @@ -207,12 +212,15 @@ def __init__( self += shp_model(input="input") self += scalar_item(input=shp_model.output, index=0) - self += slice_model_1( - input="prev_hidden", start=scalar_item.output, output=IOKey("hidden_compl") + self += slice_1(start=scalar_item.output) + self += tensor_item_1( + input="prev_hidden", index=slice_1.output, output=IOKey("hidden_compl") ) - self += slice_model_2(input="prev_hidden", stop=scalar_item.output) + + self += slice_2(stop=scalar_item.output) + self += tensor_item_2(input="prev_hidden", index=slice_2.output) self += mult_model_1(left="input", right="w_ih") - self += mult_model_2(left=slice_model_2.output, right="w_hh") + self += mult_model_2(left=tensor_item_2.output, right="w_hh") self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2(left=sum_model_1.output, right="bias_hh") self += sum_model_3( @@ -310,8 +318,13 @@ def __init__( shp_model = Shape() scalar_item = ScalarItem() - slice_model_1 = TensorSlice(start=TBD) - slice_model_2 = TensorSlice(stop=TBD) + + slice_1 = Slice(start=TBD) + slice_2 = Slice(stop=TBD) + + tensor_item_1 = TensorItem() + tensor_item_2 = TensorItem() + mult_model_1 = MatrixMultiply() mult_model_2 = MatrixMultiply() sum_model_1 = Add() @@ -321,12 +334,14 @@ def __init__( self += shp_model(input="input") self += scalar_item(input=shp_model.output, index=0) - self += slice_model_1( - input="prev_hidden", start=scalar_item.output, output=IOKey("hidden_compl") + self += slice_1(start=scalar_item.output) + self += tensor_item_1( + input="prev_hidden", index=slice_1.output, output=IOKey("hidden_compl") ) - self += slice_model_2(input="prev_hidden", stop=scalar_item.output) + self += slice_2(stop=scalar_item.output) + self += tensor_item_2(input="prev_hidden", index=slice_2.output) self += mult_model_1(left="input", right="w_ih") - self += mult_model_2(left=slice_model_2.output, right="w_hh") + self += mult_model_2(left=tensor_item_2.output, right="w_hh") self += sum_model_1(left=mult_model_1.output, right=mult_model_2.output) self += sum_model_2(left=sum_model_1.output, right="bias_hh") self += sum_model_3( From 4658244db27363b938642b48ab2023937f99fa3b Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Wed, 25 Dec 2024 14:38:05 +0300 Subject: [PATCH 6/7] minor bug fixed in physical shape --- mithril/framework/physical/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index d7e8340e..ed7e8443 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -512,6 +512,9 @@ def _pre_compile( self.jacobian_keys = jacobian_keys self.ignore_grad_keys: set[str] = set() + # Set given shapes. + self.data_store.set_shapes(shapes) + for node in self._flat_graph.nodes.values(): conn_data = node.model.conns.get_connection("output") assert conn_data is not None @@ -552,9 +555,6 @@ def _pre_compile( self.data_store.constraint_solver(updates) - # Set given shapes. - self.data_store.set_shapes(shapes) - # Set given static keys self.data_store.set_static_keys(constant_keys) From 6a8a9d411cc43514df70ff8ee13083a468919816 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 26 Dec 2024 16:30:59 +0300 Subject: [PATCH 7/7] tests are added for item + slice --- .../framework/logical/essential_primitives.py | 6 +- mithril/models/models.py | 16 ++-- tests/scripts/test_all_models.py | 75 +++++++++++++++++++ tests/scripts/test_primitive_directed.py | 37 +++++++-- tests/scripts/test_recurrent_models.py | 9 +-- 5 files changed, 119 insertions(+), 24 deletions(-) diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index bec61c82..eb8a02e8 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -1425,9 +1425,9 @@ class Slice(PrimitiveModel): def __init__( self, - start: int | None | ToBeDetermined = 0, - stop: int | None | ToBeDetermined = None, - step: int | None | ToBeDetermined = None, + start: int | None | ToBeDetermined = TBD, + stop: int | None | ToBeDetermined = TBD, + step: int | None | ToBeDetermined = TBD, name: str | None = None, ): super().__init__( diff --git a/mithril/models/models.py b/mithril/models/models.py index 796c49f0..4f72a543 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -1370,8 +1370,8 @@ def __init__( shape = Shape() scalar_item = ScalarItem() - slice_1 = Slice(start=TBD) - slice_2 = Slice(stop=TBD) + slice_1 = Slice(stop=None, step=None) + slice_2 = Slice(start=None, step=None) tensor_item_1 = TensorItem() tensor_item_2 = TensorItem() mult_model_1 = Linear(use_bias=False) @@ -1508,11 +1508,11 @@ def __init__( shape_model = Shape() scalar_item = ScalarItem() - slice_1 = Slice(stop=TBD) - slice_2 = Slice(stop=TBD) - slice_3 = Slice(start=TBD) - slice_4 = Slice(stop=TBD) - slice_5 = Slice(start=TBD) + slice_1 = Slice(start=None, step=None) + slice_2 = Slice(start=None, step=None) + slice_3 = Slice(stop=None, step=None) + slice_4 = Slice(start=None, step=None) + slice_5 = Slice(stop=None, step=None) tensor_item_1 = TensorItem() tensor_item_2 = TensorItem() @@ -1813,7 +1813,7 @@ def __init__( # current time step. shape_model = Shape() item_model = ScalarItem() - slice_model = Slice(stop=TBD) + slice_model = Slice(start=None, step=None) tensor_item = TensorItem() self += shape_model(input=f"target{idx}") diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index f6e03a14..066b7fe9 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -40,6 +40,7 @@ Greater, GreaterEqual, GroupNorm, + IOKey, IsNan, Less, LessEqual, @@ -50,6 +51,7 @@ LogicalOr, LogicalXOr, Minus, + Model, NanToNum, NormModifier, NotEqual, @@ -65,6 +67,7 @@ Slice, SquaredError, Squeeze, + TensorItem, ToList, ToTensor, ToTuple, @@ -3499,3 +3502,75 @@ def test_slice_all_keys_given_all_three_parts(): tolerances=1e-6, ignore_transform={"output", "step", "start", "stop"}, ) + + +def test_tensor_item_with_slice_1(): + model = Model() + + item_model = TensorItem() + slice_model = Slice(start=0, stop=1, step=None) + + model += slice_model + model += item_model(input="input", index=slice_model.output, output=IOKey("output")) + + input = {"input": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]} + + out_grad = {"output": [[5.0, 6.0]]} + + ref_out = {"output": [[1.0, 2.0]]} + + ref_grad = {"input": [[5.0, 6.0], [0.0, 0.0], [0.0, 0.0]]} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {}, + "trainable_keys": {"input"}, + "inference": False, + "jit": False, + }, + data={}, + params=input, + output_gradients=out_grad, + reference_outputs=ref_out, + reference_gradients=ref_grad, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"step", "start", "stop"}, + ) + + +def test_tensor_item_with_slice_2(): + model = Model() + + item_model = TensorItem() + slice_model = Slice(start=0, stop=2, step=None) + + model += slice_model + model += item_model(input="input", index=slice_model.output, output=IOKey("output")) + + input = {"input": [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]]} + + out_grad = {"output": [[[3.0, 0.0]], [[2.0, 1.0]]]} + + ref_out = {"output": [[[1.0, 2.0]], [[3.0, 4.0]]]} + + ref_grad = {"input": [[[3.0, 0.0]], [[2.0, 1.0]], [[0.0, 0.0]]]} + + compile_and_compare( + model=model, + compile_kwargs={ + "constant_keys": {}, + "trainable_keys": {"input"}, + "inference": False, + "jit": False, + }, + data={}, + params=input, + output_gradients=out_grad, + reference_outputs=ref_out, + reference_gradients=ref_grad, + assert_shapes=False, + tolerances=1e-6, + ignore_transform={"step", "start", "stop"}, + ) diff --git a/tests/scripts/test_primitive_directed.py b/tests/scripts/test_primitive_directed.py index cf8ea49b..42e069dd 100644 --- a/tests/scripts/test_primitive_directed.py +++ b/tests/scripts/test_primitive_directed.py @@ -30,7 +30,7 @@ def assert_forward( formula_key: str, - expected_result: np.ndarray | int | float | tuple | list, + expected_result: np.ndarray | int | float | tuple | list | slice, args: Any, kwargs: dict[str, Any], backends: list[Backend] = backends, @@ -45,13 +45,16 @@ def assert_forward( } primitive_fn = backend.primitive_function_dict[formula_key] result = primitive_fn(*_args, **_kwargs) - np.testing.assert_allclose( - result, - expected_result, - rtol=1e-14, - atol=1e-14, - err_msg=f"Primitive: {formula_key} failed ", - ) + if not isinstance(expected_result, np.ndarray | tuple | list): + assert result == expected_result + else: + np.testing.assert_allclose( + result, + expected_result, + rtol=1e-14, + atol=1e-14, + err_msg=f"Primitive: {formula_key} failed ", + ) def manul_vjp( @@ -4550,3 +4553,21 @@ def test_split_4(): {"input": input, "split_size": split_size, "axis": axis}, {}, ) + + +def test_slice_1(): + start = 2 + stop = 4 + step = 1 + result = slice(2, 4, 1) + + assert_forward("primitive_slice", result, (start, stop, step), {}) + + +def test_slice_2(): + start = 3 + stop = None + step = None + result = slice(3, None, None) + + assert_forward("primitive_slice", result, (start, stop, step), {}) diff --git a/tests/scripts/test_recurrent_models.py b/tests/scripts/test_recurrent_models.py index 21557293..66a4f70b 100644 --- a/tests/scripts/test_recurrent_models.py +++ b/tests/scripts/test_recurrent_models.py @@ -22,7 +22,6 @@ from mithril import TorchBackend from mithril.framework.common import NOT_GIVEN, ConnectionType from mithril.models import ( - TBD, AbsoluteError, Add, Buffer, @@ -195,8 +194,8 @@ def __init__( shp_model = Shape() scalar_item = ScalarItem() - slice_1 = Slice(start=TBD) - slice_2 = Slice(stop=TBD) + slice_1 = Slice(stop=None, step=None) + slice_2 = Slice(start=None, step=None) tensor_item_1 = TensorItem() tensor_item_2 = TensorItem() @@ -319,8 +318,8 @@ def __init__( shp_model = Shape() scalar_item = ScalarItem() - slice_1 = Slice(start=TBD) - slice_2 = Slice(stop=TBD) + slice_1 = Slice(stop=None, step=None) + slice_2 = Slice(start=None, step=None) tensor_item_1 = TensorItem() tensor_item_2 = TensorItem()