From 5f9193c0c8bd9daf306fbd8c4300c1758539a511 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 14 Nov 2024 10:16:50 +0300 Subject: [PATCH 01/26] feat: Now People will be assigned based on applied labels, github actions bot will request a change if PR title does not follow Conventional format --- .github/workflows/check-pr-title.yml | 8 ++++++- .github/workflows/pr-label-assign.yml | 31 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/pr-label-assign.yml diff --git a/.github/workflows/check-pr-title.yml b/.github/workflows/check-pr-title.yml index e10a3584..c04ce31d 100644 --- a/.github/workflows/check-pr-title.yml +++ b/.github/workflows/check-pr-title.yml @@ -7,14 +7,20 @@ on: jobs: review_pr_title: runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} steps: + - name: Check out the repository + uses: actions/checkout@v4 - name: Check PR title format id: check_title run: | TITLE="${{ github.event.pull_request.title }}" if [[ ! "$TITLE" =~ ^(feat|fix|docs|chore|perf|test|refactor): ]]; then echo "::error::PR title does not follow the convention 'type: description'. Please review the title." + gh pr review ${{ github.event.pull_request.number }} -r -b "This PR does not satisfies PR title format. Title format should be same as one of the conventional commits" exit 1 else echo "PR title meets the required format." - fi \ No newline at end of file + fi + \ No newline at end of file diff --git a/.github/workflows/pr-label-assign.yml b/.github/workflows/pr-label-assign.yml new file mode 100644 index 00000000..909d0894 --- /dev/null +++ b/.github/workflows/pr-label-assign.yml @@ -0,0 +1,31 @@ +# This wrokflow is triggered when a PR is labeled. +# It assigns reviewers based on the label applied to the PR. + +name: PR Label Assigner + +on: + pull_request_target: + types: + - labeled + +jobs: + pr-labeler: + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + steps: + + - name: Check out the repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install PyYAML + run: pip install pyyaml + + - name: Assign Reviewers # Assign responsible people based on labels + run: | + python .github/scripts/assign_reviewers.py "${{ github.event.pull_request.number }}" "${{ github.event.label.name }}" \ No newline at end of file From 8309838d5e5258daba5fd5ba585c737362283b82 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 14 Nov 2024 12:37:43 +0300 Subject: [PATCH 02/26] feat: Now actions bot requests a change in the case of tests failed --- .github/workflows/ci-test.yaml | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci-test.yaml b/.github/workflows/ci-test.yaml index efdc296c..fa0e864f 100644 --- a/.github/workflows/ci-test.yaml +++ b/.github/workflows/ci-test.yaml @@ -41,4 +41,19 @@ jobs: - name: Upload results to Codecov uses: codecov/codecov-action@v4 with: - token: ${{ secrets.CODECOV_TOKEN }} + token: ${{ secrets.CODECOV_TOKEN }} + + on_failure: + needs: build + runs-on: ubuntu-latest + env: + GH_TOKEN: ${{ github.token }} + if: ${{ failure() }} + steps: + - name: Check out the repository + uses: actions/checkout@v4 + - name: review_pr + id: review-pr + run: | + gh pr review ${{ github.event.pull_request.number }} -r -b "Tests are failed. Please review the PR." + exit 1 \ No newline at end of file From 5916d8f74246f27b0d5975b76f0ef77e97d9d27e Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 21 Nov 2024 15:09:05 +0300 Subject: [PATCH 03/26] fix: bug fix attempt 1 --- .github/workflows/check-pr-title.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check-pr-title.yml b/.github/workflows/check-pr-title.yml index c04ce31d..0dae76cd 100644 --- a/.github/workflows/check-pr-title.yml +++ b/.github/workflows/check-pr-title.yml @@ -15,7 +15,7 @@ jobs: - name: Check PR title format id: check_title run: | - TITLE="${{ github.event.pull_request.title }}" + TITLE='${{ github.event.pull_request.title }}' if [[ ! "$TITLE" =~ ^(feat|fix|docs|chore|perf|test|refactor): ]]; then echo "::error::PR title does not follow the convention 'type: description'. Please review the title." gh pr review ${{ github.event.pull_request.number }} -r -b "This PR does not satisfies PR title format. Title format should be same as one of the conventional commits" From dd6535b8740b4c03951ecfe00525d018cddae683 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Sun, 24 Nov 2024 23:56:41 +0300 Subject: [PATCH 04/26] style: fix some function annotations --- mithril/framework/common.py | 197 ++++++++++-------- mithril/framework/constraints.py | 65 +++--- mithril/framework/logical/base.py | 69 +++--- .../framework/logical/essential_primitives.py | 4 +- mithril/framework/logical/model.py | 37 ++-- mithril/framework/logical/primitive.py | 19 +- mithril/framework/physical/model.py | 10 +- mithril/framework/utils.py | 6 +- mithril/models/train_model.py | 4 +- mithril/utils/utils.py | 4 +- tests/scripts/test_constraints.py | 10 +- tests/scripts/test_scripts.py | 14 +- tests/scripts/test_shapes.py | 20 +- tests/scripts/test_summary.py | 12 +- 14 files changed, 260 insertions(+), 211 deletions(-) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index c18f5430..64a81019 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -64,8 +64,8 @@ "Scalar", "ShapesType", "_ShapesType", - "_get_summary_shapes", - "_get_summary_types", + "get_summary_shapes", + "get_summary_types", "ConstraintSolver", "NOT_AVAILABLE", "NotAvailable", @@ -171,9 +171,9 @@ class KeyType(Enum): type[int] | type[float] | type[bool] - | type[tuple] - | type[list] - | type[dict] + | type[tuple[Any, ...]] + | type[list[Any]] + | type[dict[Any, Any]] | type[Constant] | type[slice] | type[PaddingType] @@ -190,9 +190,23 @@ class KeyType(Enum): MainValueType = ( int | float - | tuple - | list - | dict + | tuple[Any, ...] + | list[Any] + | dict[Any, Any] + | None + | EllipsisType + | PaddingType + | Constant + | slice + | Dtype +) +# Mainvalue type for isintance check +MainValueInstance = ( + int + | float + | tuple # type: ignore + | list # type: ignore + | dict # type: ignore | None | EllipsisType | PaddingType @@ -200,7 +214,8 @@ class KeyType(Enum): | slice | Dtype ) -TensorValueType = int | float | tuple | list | Constant | NestedListType +TensorValueType = int | float | tuple[Any, ...] | list[Any] | Constant | NestedListType + LossKey = "loss" FinalCost = "final_cost" @@ -443,8 +458,8 @@ def match(self, other: ConstraintSolver) -> Updates: @dataclass class Updates: - shape_updates: set[Tensor] = field(default_factory=lambda: set()) - value_updates: set[Tensor | Scalar] = field(default_factory=lambda: set()) + shape_updates: set[Tensor[Any]] = field(default_factory=lambda: set()) + value_updates: set[Tensor[Any] | Scalar] = field(default_factory=lambda: set()) uniadic_updates: set[Uniadic] = field(default_factory=lambda: set()) node_updates: set[ShapeNode] = field(default_factory=lambda: set()) constraints: dict[UpdateType, set[Constraint]] = field( @@ -490,7 +505,7 @@ def _add_variadic(self, symbol: Variadic): self.shape_updates.add(tensor) self.constraints[UpdateType.SHAPE] |= tensor.shape_constraints - def _add_type_update(self, symbol: Tensor | Scalar): + def _add_type_update(self, symbol: Tensor[Any] | Scalar): self.constraints[UpdateType.TYPE] |= symbol.type_constraints def __ior__(self, other: Updates) -> Updates: @@ -504,12 +519,12 @@ def __ior__(self, other: Updates) -> Updates: def _get_shapes( - data_dict: dict[str, Tensor | Scalar], - uniadic_keys=None, - varadic_keys=None, + data_dict: dict[str, Tensor[Any] | Scalar], + uniadic_keys: dict[UniadicRecord, str] | None = None, + varadic_keys: dict[Variadic, str] | None = None, symbolic: bool = True, verbose: bool = False, - key_mappings: dict | None = None, + key_mappings: dict[str, str] | None = None, ) -> _ShapesType: if key_mappings is None: key_mappings = {} @@ -550,10 +565,10 @@ def is_valued(self) -> bool: def all_constraints(self) -> set[Constraint]: return self.shape_constraints | self.type_constraints - def _convert_value(self, backend: Backend) -> Any: + def _convert_value(self, backend: Backend[Any]) -> Any: raise NotImplementedError("No '_convert_value' method implemented.") - def finalize_match(self, other: Tensor | Scalar): + def finalize_match(self, other: Tensor[Any] | Scalar): if (typ_1 := type(other)) != (typ_2 := type(self)): raise TypeError( f"Replacement can be done for only same types. Got {typ_1} and {typ_2}" @@ -617,15 +632,15 @@ def remove_constraint(self, constraint: Constraint): def match_shapes(self, other: BaseData): return Updates() - def match(self, other: Tensor | Scalar) -> Updates: + def match(self, other: Tensor[Any] | Scalar) -> Updates: updates = Updates() if self != other: updates = Updates() updates |= self.set_type(other._type) updates |= other.set_type(self._type) if isinstance(other, Tensor): - updates |= self.match_shapes(other) assert isinstance(self, Tensor) + updates |= self.match_shapes(other) is_diff = self._differentiable | other._differentiable self._differentiable = other._differentiable = is_diff @@ -681,7 +696,7 @@ def is_valued(self) -> bool: def _set_as_physical(self): super()._set_as_physical() - def _convert_value(self, backend: Backend) -> DataType | ToBeDetermined: + def _convert_value(self, backend: Backend[Any]) -> DataType | ToBeDetermined: if isinstance(self.temp_value, Constant): self.value = backend.array( epsilon_table[backend.precision][self.temp_value] @@ -690,7 +705,9 @@ def _convert_value(self, backend: Backend) -> DataType | ToBeDetermined: self.value = backend.array(self.temp_value) return self.value - def make_physical(self, backend: Backend, memo: dict[int, Tensor | Scalar]): + def make_physical( + self, backend: Backend[Any], memo: dict[int, Tensor[Any] | Scalar] + ): physical_tensor = deepcopy(self, memo) # Update data as physical data. physical_tensor._set_as_physical() @@ -698,7 +715,7 @@ def make_physical(self, backend: Backend, memo: dict[int, Tensor | Scalar]): physical_tensor._convert_value(backend) return physical_tensor - def __deepcopy__(self, memo: dict[int, Tensor | Scalar]): + def __deepcopy__(self, memo: dict[int, Tensor[Any] | Scalar]): # Check if the object is already in the memo dictionary. if id(self) in memo: return memo[id(self)] @@ -714,7 +731,7 @@ def __deepcopy__(self, memo: dict[int, Tensor | Scalar]): setattr(new_instance, k, deepcopy(v, memo)) return new_instance - def match_shapes(self, other: Tensor): # type: ignore[override] + def match_shapes(self, other: Tensor[Any]): # type: ignore[override] updates = Updates() if other.shape != self.shape: updates |= self.shape.merge(other.shape) @@ -731,9 +748,11 @@ def match_shapes(self, other: Tensor): # type: ignore[override] # to self object? If we should, we also need to transfer "interval" attribute # which requires handling of interval arithmetic in logical level also. - def set_value(self, value: DataType | TensorValueType) -> Updates: # type: ignore[override] + def set_value(self, value: DataType | TensorValueType | str) -> Updates: # type: ignore[override] if self._logical_data: - assert isinstance(value, TensorValueType) + assert isinstance( + value, int | float | tuple | list | Constant | NestedListType + ) return self._set_logical_value(value) else: assert self.is_tensor_type(value) @@ -822,11 +841,13 @@ def find_type(self, value: MainValueType | str) -> ScalarType: else: return find_type(value) - def _convert_value(self, backend): + def _convert_value(self, backend: Backend[Any]): self.value = backend.cast(self.value) return self.value - def make_physical(self, backend: Backend, memo: dict[int, Tensor | Scalar]): + def make_physical( + self, backend: Backend[Any], memo: dict[int, Tensor[Any] | Scalar] + ): new_scalar = deepcopy(self, memo) if id(self) not in memo: # Update data as physical data. @@ -837,7 +858,7 @@ def make_physical(self, backend: Backend, memo: dict[int, Tensor | Scalar]): return new_scalar - def set_value(self, value: MainValueType) -> Updates: + def set_value(self, value: MainValueType | str) -> Updates: # Check value type! updates = Updates() @@ -873,7 +894,7 @@ def __init__( interval = [] self.interval = interval - def construct(self, shape_node: ShapeNode) -> Tensor: + def construct(self, shape_node: ShapeNode) -> Tensor[Any]: return Tensor( shape=shape_node, possible_types=self._type, @@ -884,7 +905,7 @@ def construct(self, shape_node: ShapeNode) -> Tensor: @dataclass class IOHyperEdge: - data: Tensor | Scalar + data: Tensor[Any] | Scalar key_origin: str | None = None @property @@ -1171,7 +1192,7 @@ def __hash__(self) -> int: class Connection(TemplateBase): - def __init__(self, key, metadata, is_key_autogenerated: bool): + def __init__(self, key: str, metadata: IOHyperEdge, is_key_autogenerated: bool): self.data = ConnectionData(key, metadata, is_key_autogenerated, self) @property @@ -1244,13 +1265,7 @@ def __init__(self, *connections: Connection | str, key: IOKey | None = None): self.key = key for item in connections: conn: ConnectionData | str - if isinstance(item, Connection): - conn = item.data - elif isinstance(item, str): - conn = item - else: - raise KeyError("Requires Connection object or string!") - + conn = item.data if isinstance(item, Connection) else item self.connections.add(conn) @@ -1267,6 +1282,19 @@ def __init__(self, *connections: Connection | str, key: IOKey | None = None): | NestedListType ) +ConnectionInstanceType = ( + str + | ConnectionData + | Connect + | MainValueInstance + | ExtendTemplate + | NullConnection + | IOKey + | Connection + | NotAvailable + | NestedListType +) + class Connections: """This class maintains all the connections and their operations in model. @@ -1351,7 +1379,7 @@ def remove_connection(self, connection: ConnectionData) -> None: for _type in KeyType: self._connection_dict[_type].pop(connection.key, None) - def get_data(self, key: str) -> Scalar | Tensor: + def get_data(self, key: str) -> Scalar | Tensor[Any]: # if (metadata := self._get_metadata(key)) is not None: # return metadata.data # raise KeyError(f"Key {key} is not found in connections.") @@ -1428,7 +1456,7 @@ def __hash__(self) -> int: def __eq__(self, other: Uniadic) -> bool: # type: ignore return id(self.metadata) == id(other.metadata) - def set_value(self, value): # Do we need set_value + def set_value(self, value: int | set[int] | None): # Do we need set_value prev_value = self.metadata.possible_values new_value = self.metadata.update_possible_values(value) return prev_value != new_value @@ -1500,9 +1528,9 @@ def update_possible_values(self, values: int | set[int] | None) -> set[int] | No return self.possible_values elif self.possible_values is None: self.possible_values = values - elif values is not None and len(intersect := self.possible_values & values) > 0: + elif len(intersect := self.possible_values & values) > 0: self.possible_values = intersect - elif values is not None and self.possible_values is not None: + else: raise ValueError("Possible values mismatch!") return self.possible_values @@ -2197,7 +2225,7 @@ class ShapeNode: def __init__(self) -> None: self.reprs: list[ShapeRepr] = [] - self.referees: set[Tensor] = set() + self.referees: set[Tensor[Any]] = set() def add_repr(self, repr: ShapeRepr): self.reprs.append(repr) @@ -2206,7 +2234,7 @@ def add_repr(self, repr: ShapeRepr): def merge(self, other: ShapeNode) -> Updates: updates = Updates() resolved_reprs: set[ShapeRepr] = set() - remaining_reprs = [] + remaining_reprs: list[ShapeRepr] = [] add_constraint = False if self != other: @@ -2242,7 +2270,7 @@ def merge(self, other: ShapeNode) -> Updates: def combine(self): updates = Updates() - same_reprs = set() + same_reprs: set[ShapeRepr] = set() # Iterate over all repr pairs and remove matching reprs. for repr, other_repr in combinations(self.reprs, 2): if repr not in same_reprs and other_repr not in same_reprs: @@ -2269,10 +2297,10 @@ def set_values(self, values: Sequence[int | None]): def get_shapes( self, - u_keys: dict[UniadicRecord | Variadic, str] | None = None, - v_keys: dict[UniadicRecord | Variadic, str] | None = None, - symbolic=True, - verbose=False, + u_keys: dict[UniadicRecord, str] | None = None, + v_keys: dict[Variadic, str] | None = None, + symbolic: bool = True, + verbose: bool = False, ) -> list[int | str | None] | list[list[int | str | None]]: if u_keys is None: u_keys = {} @@ -2347,6 +2375,7 @@ def get_most_informative_repr(self) -> ShapeRepr: type ShapeType = Uniadic | Variadic type ConstrainResultType = tuple[bool, Updates] +type ConstraintFunctionType = Callable[..., ConstrainResultType] class ShapeRepr: @@ -2486,9 +2515,9 @@ def _update_uniadics( def get_shapes( self, - u_keys: dict[UniadicRecord | Variadic, str] | None = None, - v_keys: dict[UniadicRecord | Variadic, str] | None = None, - symbolic=True, + u_keys: dict[UniadicRecord, str] | None = None, + v_keys: dict[Variadic, str] | None = None, + symbolic: bool = True, ) -> list[int | str | None]: if u_keys is None: u_keys = {} @@ -2509,8 +2538,12 @@ def get_shapes( return prefix_list + var_list + suffix_list @staticmethod - def _get_uniadic_shapes(uniadic_list, cache, symbolic=True): - final_list = [] + def _get_uniadic_shapes( + uniadic_list: list[Uniadic], + cache: dict[UniadicRecord, str], + symbolic: bool = True, + ): + final_list: list[int | str | None] = [] for uniadic in uniadic_list: if (value := uniadic.value) is None and symbolic: value = cache.setdefault(uniadic.metadata, "u" + str(len(cache) + 1)) @@ -2671,7 +2704,7 @@ class Constraint: call_counter: int = 0 post_processes: set[Callable] = field(default_factory=lambda: set()) - def __call__(self, keys: list[Scalar | Tensor]): + def __call__(self, keys: list[Scalar | Tensor]) -> ConstrainResultType: status = False updates = Updates() if self.type == UpdateType.SHAPE: @@ -2688,7 +2721,7 @@ def __call__(self, keys: list[Scalar | Tensor]): self.call_counter += 1 return status, updates - def add_post_process(self, fn: Callable): + def add_post_process(self, fn: ConstraintFunctionType): self.post_processes.add(fn) def create_post_constraints(self): @@ -3183,23 +3216,12 @@ def create_shape_repr( def get_summary( - conns: dict, + conns: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]], name: str, - shape: dict | None = None, - types: dict | None = None, - params: dict | None = None, + shape: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] | None = None, + types: dict[str, tuple[dict[str, str], dict[str, str]]] | None = None, + params: dict[str, tuple[dict[str, str], dict[str, str]]] | None = None, ) -> Table: - """Constructs the summary table based on connections and shapes - - Args: - conns (dict): connection dict - name (str): given name of the table - shape (dict | None, optional): Shape information of all keys. Defaults to None. - - Returns: - Table: Table object that holds all summary information - """ - stub_input = "Inputs" stub_output = "Outputs" stub_len = len(max(stub_input, stub_output, key=len)) @@ -3214,7 +3236,7 @@ def get_summary( sub_columns = [keys_name, shapes_name, type_name, conn_name, params_name] adjustments = ["left", "right", "left", "left", "right"] - removed_cols = [] + removed_cols: list[int] = [] if shape is not None: align_shapes( [shape_dict for shape_tuple in shape.values() for shape_dict in shape_tuple] @@ -3310,20 +3332,28 @@ def get_summary( in_keys = [[key] for key in input_conn] out_keys = [[key] for key in output_conn] - input_shape = input_shape.values() - output_shape = output_shape.values() + input_shapes = list(input_shape.values()) + output_shapes = list(output_shape.values()) in_types = [[key] for key in input_types.values()] out_types = [[key] for key in output_types.values()] - input_conn = [value if value else ["--"] for value in input_conn.values()] - output_conn = [value if value else ["--"] for value in output_conn.values()] + model_input_conn = [value if value else ["--"] for value in input_conn.values()] + model_output_conn = [ + value if value else ["--"] for value in output_conn.values() + ] in_params = [[param] for param in input_params.values()] out_params = [[param] for param in output_params.values()] - input_args = [in_keys, input_shape, in_types, input_conn, in_params] - output_args = [out_keys, output_shape, out_types, output_conn, out_params] + input_args = [in_keys, input_shapes, in_types, model_input_conn, in_params] + output_args = [ + out_keys, + output_shapes, + out_types, + model_output_conn, + out_params, + ] # remove not given columns input_args = [ @@ -3363,8 +3393,11 @@ def get_summary( return table -def _get_summary_shapes(model_shapes: dict, conn_info: dict): - shape_info = {} +def get_summary_shapes( + model_shapes: dict[str, _ShapesType], + conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]], +): + shape_info: dict[str, tuple[_ShapesType, _ShapesType]] = {} for model_name in conn_info: shape = model_shapes[model_name] input_conns, output_conns = conn_info[model_name] @@ -3374,7 +3407,7 @@ def _get_summary_shapes(model_shapes: dict, conn_info: dict): return shape_info -def _get_summary_types(name_mappings: dict, data_memo=None): +def get_summary_types(name_mappings: dict, data_memo=None): if data_memo is None: data_memo = {} diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index e19e9ac8..bb6f7193 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -13,11 +13,11 @@ # limitations under the License. import math -from collections.abc import Callable, Sequence +from collections.abc import Sequence from functools import reduce from itertools import product, zip_longest from types import EllipsisType, GenericAlias, NoneType, UnionType -from typing import get_origin +from typing import Any, get_origin from ..utils.type_utils import ( is_axis_reduce_type, @@ -36,6 +36,7 @@ TBD, Constant, ConstrainResultType, + ConstraintFunctionType, NestedListType, PossibleValues, Scalar, @@ -148,7 +149,7 @@ def _reduce_union_type( return new_type -def general_tensor_type_constraint(*args: Scalar | Tensor): +def general_tensor_type_constraint(*args: Scalar | Tensor[Any]): # NOTE: Assumes first argument is always output as other constraints. # Also requires all types of args consists of any combination of # float, int and bool. For instance, int | float is an acceptable type @@ -264,7 +265,7 @@ def general_tensor_type_constraint(*args: Scalar | Tensor): def floor_divide_type_constraint( - output: Tensor, numerator: Tensor, denominator: Tensor + output: Tensor[Any], numerator: Tensor[Any], denominator: Tensor[Any] ): status = False updates = Updates() @@ -548,7 +549,7 @@ def scalar_item_type_constraint(output: Scalar, input: Scalar, index: Scalar): return status, updates -def tensor_to_list_type_constraint(output: Scalar, input: Tensor): +def tensor_to_list_type_constraint(output: Scalar, input: Tensor[Any]): status = not is_union(output._type) updates = Updates() assert input._temp_shape is not None @@ -599,7 +600,7 @@ def tensor_to_list_type_constraint(output: Scalar, input: Tensor): return status, updates -def reduce_type_constraint(output: Tensor, input: Tensor): +def reduce_type_constraint(output: Tensor[Any], input: Tensor[Any]): updates = Updates() input_type = input._type @@ -1201,7 +1202,9 @@ def bcast_helper( return bcast_exit_condition(output, left, right, index), updates -def bcast(output: Tensor, left: Tensor, right: Tensor) -> ConstrainResultType: +def bcast( + output: Tensor[Any], left: Tensor[Any], right: Tensor[Any] +) -> ConstrainResultType: assert output._temp_shape is not None, "Output shape of broadcast is not set!" assert left._temp_shape is not None, "Left shape of broadcast is not set!" assert right._temp_shape is not None, "Right shape of broadcast is not set!" @@ -1209,7 +1212,7 @@ def bcast(output: Tensor, left: Tensor, right: Tensor) -> ConstrainResultType: def bcast_matrix_mult( - output: Tensor, left: Tensor, right: Tensor + output: Tensor[Any], left: Tensor[Any], right: Tensor[Any] ) -> ConstrainResultType: assert output._temp_shape is not None, "Output shape of broadcast is not set!" assert left._temp_shape is not None, "Left shape of broadcast is not set!" @@ -1261,9 +1264,9 @@ def bcast_exit_condition( def bcast_error_check( - output: Tensor, - left: Tensor, - right: Tensor, + output: Tensor[Any], + left: Tensor[Any], + right: Tensor[Any], index: int = 0, ) -> ConstrainResultType: assert left._temp_shape is not None, "Left shape of broadcast is not set!" @@ -1346,13 +1349,13 @@ def bcast_is_compatible( def bcast_mat_mul_check( - output: Tensor, left: Tensor, right: Tensor + output: Tensor[Any], left: Tensor[Any], right: Tensor[Any] ) -> ConstrainResultType: return bcast_error_check(output, left, right, index=2) def reduce_constraints( - output: Tensor, input: Tensor, axis: Scalar, keepdim: Scalar | None = None + output: Tensor[Any], input: Tensor[Any], axis: Scalar, keepdim: Scalar | None = None ) -> ConstrainResultType: updates = Updates() assert input._temp_shape is not None, "Input shape of reduce is not set!" @@ -1594,7 +1597,7 @@ def reduce_constraints( def concat_constraints( - output: Tensor, axis: Scalar, *inputs: Tensor + output: Tensor[Any], axis: Scalar, *inputs: Tensor[Any] ) -> ConstrainResultType: status = False updates = Updates() @@ -1716,7 +1719,7 @@ def concat_constraints( def reverse_constraints( - output: Tensor, input: Tensor, axes: Scalar + output: Tensor[Any], input: Tensor[Any], axes: Scalar ) -> ConstrainResultType: status = False assert input._temp_shape is not None, "Input shape of reverse is not set!" @@ -1782,7 +1785,7 @@ def reverse_constraints( def polynomial_features_constraints( - output: Tensor, input: Tensor, degree: Scalar + output: Tensor[Any], input: Tensor[Any], degree: Scalar ) -> ConstrainResultType: status = False updates = Updates() @@ -1880,8 +1883,8 @@ def sliding_window_constraint_helper( def sliding_window_1d_constraints( - output: Tensor, - input: Tensor, + output: Tensor[Any], + input: Tensor[Any], stride: Scalar, padding: Scalar, dilation: Scalar, @@ -2483,7 +2486,7 @@ def validate_bcast(input: ShapeRepr, shape: tuple[int, ...]): def reshape_constraints( - output: Tensor, input: Tensor, shape: Scalar + output: Tensor[Any], input: Tensor[Any], shape: Scalar ) -> ConstrainResultType: # TODO: We can add inference for the case where # shape = (1,2,3,4), input_shape = (1, 2, 4, "u1") for example. @@ -2689,7 +2692,9 @@ def squeeze_constraints(output: Tensor, input: Tensor) -> ConstrainResultType: return status, updates -def size_constraints(output: Scalar, input: Tensor, dim: Scalar) -> ConstrainResultType: +def size_constraints( + output: Scalar, input: Tensor[Any], dim: Scalar +) -> ConstrainResultType: assert input._temp_shape is not None, "Input shape of Size is not set!" input_shape: ShapeRepr = input._temp_shape @@ -2837,7 +2842,7 @@ def size_constraints(output: Scalar, input: Tensor, dim: Scalar) -> ConstrainRes return status, updates -def shape_constraints(output: Scalar, input: Tensor) -> ConstrainResultType: +def shape_constraints(output: Scalar, input: Tensor[Any]) -> ConstrainResultType: assert input._temp_shape is not None, "Input shape of Shape is not set!" input_shape: ShapeRepr = input._temp_shape output_val = output.value @@ -2870,6 +2875,7 @@ def eye_constraints(output: Tensor, N: Scalar, M: Scalar) -> ConstrainResultType m_uni_valued = isinstance(m_uni.value, int) if n_valued and not n_uni_valued: + assert isinstance(N.value, int) n_uni.set_value(N.value) updates.add(n_uni) elif n_uni_valued and not n_valued: @@ -2877,6 +2883,7 @@ def eye_constraints(output: Tensor, N: Scalar, M: Scalar) -> ConstrainResultType updates.add(N) if m_valued and not m_uni_valued: + assert isinstance(M.value, int) m_uni.set_value(M.value) updates.add(m_uni) elif m_uni_valued and not m_valued: @@ -3014,7 +3021,7 @@ def swap_axes_constraints( return status, updates -def to_tensor_constraints(output: Tensor, input: Scalar) -> ConstrainResultType: +def to_tensor_constraints(output: Tensor[Any], input: Scalar) -> ConstrainResultType: updates = Updates() status = False assert output._temp_shape is not None, "Output shape of ToTensor is not set!" @@ -3064,7 +3071,9 @@ def to_tensor_constraints(output: Tensor, input: Scalar) -> ConstrainResultType: return status, updates -def tensor_to_list_constraints(output: Scalar, input: Tensor) -> ConstrainResultType: +def tensor_to_list_constraints( + output: Scalar, input: Tensor[Any] +) -> ConstrainResultType: assert input._temp_shape is not None, "Input shape of TensorToList is not set!" input_shape: ShapeRepr = input._temp_shape output_val = output.value @@ -3096,7 +3105,7 @@ def tensor_to_list_constraints(output: Scalar, input: Tensor) -> ConstrainResult return status, updates -def item_constraints(output: Scalar, input: Tensor) -> ConstrainResultType: +def item_constraints(output: Scalar, input: Tensor[Any]) -> ConstrainResultType: assert input._temp_shape is not None, "Input shape of Item is not set!" input_shape: ShapeRepr = input._temp_shape updates = Updates() @@ -3199,7 +3208,7 @@ def to_list_constraints(output: Scalar, *args: Scalar) -> ConstrainResultType: def tensor_item_constraints( - output: Tensor, input: Tensor, index: Scalar + output: Tensor[Any], input: Tensor[Any], index: Scalar ) -> ConstrainResultType: assert output._temp_shape is not None, "Output shape of TensorItem is not set!" assert input._temp_shape is not None, "Input shape of TensorItem is not set!" @@ -3332,7 +3341,7 @@ def tensor_item_constraint_helper( def tensor_slice_constraints( - output: Tensor, input: Tensor, start: Scalar, stop: Scalar, step: Scalar + output: Tensor[Any], input: Tensor[Any], start: Scalar, stop: Scalar, step: Scalar ) -> ConstrainResultType: assert output._temp_shape is not None, "Output shape of TensorSlice is not set!" assert input._temp_shape is not None, "Input shape of TensorSlice is not set!" @@ -3495,7 +3504,7 @@ def tuple_converter_constraint(output: Scalar, input: Scalar) -> ConstrainResult return status, updates -type_constraints = { +type_constraints: set[ConstraintFunctionType] = { general_tensor_type_constraint, floor_divide_type_constraint, scalar_slice_type_constraint, @@ -3504,7 +3513,7 @@ def tuple_converter_constraint(output: Scalar, input: Scalar) -> ConstrainResult reduce_type_constraint, } -post_process_map: dict[Callable, set[Callable]] = { +post_process_map: dict[ConstraintFunctionType, set[ConstraintFunctionType]] = { bcast: {bcast_error_check}, bcast_matrix_mult: {bcast_mat_mul_check}, } diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index cb39f014..f527e20c 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -15,7 +15,7 @@ from __future__ import annotations import abc -from collections.abc import Callable, Mapping +from collections.abc import Mapping from dataclasses import dataclass from types import UnionType from typing import Any @@ -28,17 +28,21 @@ Connections, ConnectionType, Constraint, + ConstraintFunctionType, ConstraintSolver, IOHyperEdge, MainValueType, NotAvailable, Scalar, + ShapeNode, ShapesType, ShapeTemplateType, ShapeType, Tensor, + UniadicRecord, Updates, UpdateType, + Variadic, _get_shapes, _ShapesType, create_shape_repr, @@ -102,25 +106,25 @@ def summary( symbolic: bool = False, name: str | None = None, alternative_shapes: bool = False, - uni_cache: dict | None = None, - var_cache: dict | None = None, + uni_cache: dict[UniadicRecord, str] | None = None, + var_cache: dict[Variadic, str] | None = None, ) -> None: raise NotImplementedError("Implement summary method!") @property - def enforce_jit(self): + def enforce_jit(self) -> bool: return self._enforce_jit @enforce_jit.setter - def enforce_jit(self, value): + def enforce_jit(self, value: bool) -> None: self._enforce_jit = value @property - def jittable(self): + def jittable(self) -> bool: return self._jittable @property - def shapes(self): + def shapes(self) -> _ShapesType: return self.get_shapes() @property @@ -157,7 +161,10 @@ def _get_outermost_parent(self): return model def _generate_keys( - self, symbolic=True, include_internals=True, include_outputs=False + self, + symbolic: bool = True, + include_internals: bool = True, + include_outputs: bool = False, ) -> dict[str, str]: return {} @@ -181,9 +188,9 @@ def _freeze(self) -> None: ... def extract_connection_info( self, name_mappings: dict[BaseModel, str], - data_to_key_map: dict[Tensor | Scalar, list[str]] | None = None, - data_memo: dict | None = None, - ) -> dict[str, tuple[dict, dict]]: + data_to_key_map: dict[Tensor[Any] | Scalar, list[str]] | None = None, + data_memo: Mapping[int, Tensor[Any] | Scalar] | None = None, + ) -> dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]]: raise NotImplementedError("Implement extract_connection_info method!") def _create_connection( @@ -223,11 +230,9 @@ def _set_shapes( model = self._get_outermost_parent() metadatas: OrderedSet[IOHyperEdge] = OrderedSet() used_keys: dict[str | int, ShapeType] = {} - shape_nodes = {} + shape_nodes: dict[str, ShapeNode] = {} for key, shape in shapes.items(): metadata = self.conns.extract_metadata(key) - if metadata is None: - raise KeyError("Requires valid IO connection to set shapes!") if metadata in metadatas: raise KeyError("shape of same connection has already given") metadatas.add(metadata) @@ -310,7 +315,7 @@ def set_types( # run the constraints for updating affected connections model.constraint_solver(updates) - def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates: + def _set_value(self, key: ConnectionData, value: MainValueType | str) -> Updates: """ Set value for the given connection. @@ -328,7 +333,11 @@ def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates: return key.metadata.data.set_value(value) def get_shapes( - self, uni_keys=None, var_keys=None, symbolic=True, verbose=False + self, + uni_keys: dict[UniadicRecord, str] | None = None, + var_keys: dict[Variadic, str] | None = None, + symbolic: bool = True, + verbose: bool = False, ) -> _ShapesType: return _get_shapes( data_dict={ @@ -343,9 +352,9 @@ def get_shapes( def _set_constraint( self, - fn: Callable, + fn: ConstraintFunctionType, keys: list[str], - post_processes: set[Callable] | None = None, + post_processes: set[ConstraintFunctionType] | None = None, type: UpdateType | None = None, ): constr_conns = [self.conns.all[key] for key in keys] @@ -373,9 +382,9 @@ def _set_constraint( def set_constraint( self, - fn: Callable, + fn: ConstraintFunctionType, keys: list[str], - post_processes: set[Callable] | None = None, + post_processes: set[ConstraintFunctionType] | None = None, type: UpdateType = UpdateType.SHAPE, ) -> None: self.assigned_constraints.append({"fn": fn.__name__, "keys": keys}) @@ -396,9 +405,6 @@ def canonical_output(self) -> Connection | NotAvailable: return self._canonical_output.conn def set_canonical_input(self, given_conn: str | Connection): - if not isinstance(given_conn, str | Connection): - raise ValueError("Set canonical input takes only a 'key' or 'connection'!") - if isinstance(given_conn, str): conn = self.conns.all.get(given_conn) if conn is None: @@ -417,9 +423,6 @@ def set_canonical_input(self, given_conn: str | Connection): self._canonical_input = conn def set_canonical_output(self, given_conn: str | Connection): - if not isinstance(given_conn, str | Connection): - raise ValueError("Set canonical output takes only a 'key' or 'connection'!") - if isinstance(given_conn, str): conn = self.conns.all.get(given_conn) if conn is None: @@ -537,13 +540,13 @@ def __init__(self, connections: Connections) -> None: ] = {} # Add new model to dependency map, model_dag is created in extend - def add_model_dag(self, model: BaseModel, model_dag): + def add_model_dag(self, model: BaseModel, model_dag: dict[str, ConnectionData]): updated_conns: OrderedSet[ConnectionData] = OrderedSet() for local_key, conn in model_dag.items(): if local_key in model.conns.input_keys: - specs = OrderedSet( + specs: OrderedSet[ConnectionData] = OrderedSet( [ - model_dag.get(conn.key) + model_dag[conn.key] for conn in model.dependency_map.get_dependent_output_conns( local_key ) @@ -553,10 +556,10 @@ def add_model_dag(self, model: BaseModel, model_dag): self._local_input_dependency_map[conn] = [(model, specs)] updated_conns.add(conn) - elif local_key in model.conns.output_keys: + else: specs = OrderedSet( [ - model_dag.get(conn.key) + model_dag[conn.key] for conn in model.dependency_map.get_dependent_input_conns( local_key ) @@ -808,7 +811,7 @@ def get_input_key_dependency(self, key: str): for spec in item[1] ] ) - if key in self._local_input_dependency_map + if conn_data in self._local_input_dependency_map else OrderedSet() ) return specs @@ -842,7 +845,7 @@ def get_output_key_dependency(self, key: str): # TODO: add test checking the while key_stack |= ( self._local_output_dependency_map[conn_data][1] - if key in self._local_output_dependency_map + if conn_data in self._local_output_dependency_map else OrderedSet() ) return specs diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index 1860ca0c..01458e6f 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -906,7 +906,7 @@ class SingleInputOperation(PrimitiveModel): def __init__( self, - formula_key, + formula_key: str, polymorphic_constraint: bool = True, **kwargs: TensorType | Scalar, ) -> None: @@ -914,7 +914,7 @@ def __init__( output=TensorType([("Var", ...)]), input=TensorType([("Var", ...)]) ) # Finalize kwargs. - new_kwargs: Mapping = default_kwargs | kwargs + new_kwargs: Mapping[str, TensorType | Scalar] = default_kwargs | kwargs super().__init__(formula_key, **new_kwargs) if polymorphic_constraint: diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 4f4e00ea..51338af5 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -14,8 +14,9 @@ from __future__ import annotations +from collections.abc import Mapping from types import UnionType -from typing import Self, TypeVar, overload +from typing import Any, Self, TypeVar, overload from ...core import Constant from ...utils.utils import OrderedSet, find_dominant_type @@ -26,11 +27,13 @@ Connect, Connection, ConnectionData, + ConnectionInstanceType, ConnectionType, ExtendTemplate, IOHyperEdge, IOKey, KeyType, + MainValueInstance, MainValueType, NullConnection, Scalar, @@ -40,9 +43,9 @@ ToBeDetermined, Updates, Variadic, - _get_summary_shapes, - _get_summary_types, get_summary, + get_summary_shapes, + get_summary_types, ) from ..utils import define_unique_names from .base import ExtendInfo @@ -233,9 +236,9 @@ def set_outputs(self, *args: str | Connection, **kwargs: str | Connection) -> No # Merge new_conn with given connection. self.merge_connections(new_conn, conn_data) - def _set_value(self, key: ConnectionData, value: MainValueType) -> Updates: + def _set_value(self, key: ConnectionData, value: MainValueType | str) -> Updates: if isinstance(key.metadata.data, Tensor): - extend_value: MainValueType | IOKey = value + extend_value: MainValueType | IOKey | str = value # If ToTensor output key is reserved key, rename it. if key.conn.key == "input": self.inter_key_count += 1 @@ -327,11 +330,11 @@ def _add_connection( match_connection = None if isinstance( - given_connection, MainValueType | NullConnection + given_connection, MainValueInstance | NullConnection ): # or given_connection == NOT_GIVEN: # Immediate values can be provided only for inputs. # if given_connection != NOT_GIVEN: - if isinstance(given_connection, MainValueType): + if isinstance(given_connection, MainValueInstance): set_value = given_connection if expose is None: @@ -462,7 +465,7 @@ def _unroll_template( ) else: assert isinstance( - connection, ConnectionType + connection, ConnectionInstanceType ) # TODO: check if needed connections.append(connection) self.extend( @@ -479,7 +482,7 @@ def _unroll_template( for local_key, outer_con in zip( model._input_keys, connections, strict=False ): - if isinstance(outer_con, MainValueType): + if isinstance(outer_con, MainValueInstance): conn = model.conns.get_connection(local_key) assert conn is not None conn_data = conn.metadata @@ -607,7 +610,7 @@ def handle_auto_conversion( updates: Updates, ) -> ConnectionType: connection_type: type[Tensor] | type[Scalar] | None = None - if isinstance(connection, MainValueType): + if isinstance(connection, MainValueInstance): connection_type = Scalar elif isinstance(connection, ConnectionData): @@ -666,12 +669,12 @@ def handle_auto_conversion( # Create data object based on given_value or given key_type. if is_value_given: - assert isinstance(set_value, MainValueType | str) + assert isinstance(set_value, MainValueInstance | str) data = Scalar(value=set_value) elif key_type == Scalar: if set_type is None: - set_type = MainValueType | type[str] + set_type = MainValueInstance | str data = Scalar(possible_types=set_type) else: @@ -1099,7 +1102,7 @@ def extend( value, type(template_conn.metadata.data) ) - elif isinstance(value, MainValueType): + elif isinstance(value, MainValueInstance): if key in model.conns.output_keys: raise KeyError( f"{key} key is an output of the model, output values could " @@ -1110,7 +1113,7 @@ def extend( # Hold shape information for IOKey type values in order # to set all in a bulk after all connections are added. if value._shape is not None: - shape_info |= {key: value._shape} + shape_info |= {key: value._shape} # if value._type is not None: type_info[key] = value._type @@ -1508,11 +1511,11 @@ def summary( } if shapes: # extract model shapes - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) if types: # extract model types - type_info = _get_summary_types(name_mappings) + type_info = get_summary_types(name_mappings) if not name: name = self.__class__.__name__ @@ -1545,7 +1548,7 @@ def extract_connection_info( self, name_mappings: dict[BaseModel, str], data_to_key_map: dict[Tensor | Scalar, list[str]] | None = None, - data_memo: dict | None = None, + data_memo: Mapping[int, Tensor[Any] | Scalar] | None = None, ): conn_info: dict[str, tuple[dict, dict]] = {} if self._input_keys: diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 415eada1..d3cd083a 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Mapping +from typing import Any + from ...utils.utils import OrderedSet from ..common import ( NOT_AVAILABLE, @@ -24,10 +27,10 @@ Tensor, TensorType, Updates, - _get_summary_shapes, - _get_summary_types, create_shape_map, get_summary, + get_summary_shapes, + get_summary_types, ) from ..utils import define_unique_names from .base import BaseModel @@ -42,7 +45,9 @@ class PrimitiveModel(BaseModel): cache_name = "cache" output: Connection - def __init__(self, formula_key, **kwargs: Tensor | TensorType | Scalar) -> None: + def __init__( + self, formula_key: str, **kwargs: Tensor[Any] | TensorType | Scalar + ) -> None: self.formula_key = formula_key self.grad_formula = formula_key + "_grad" @@ -134,8 +139,8 @@ def convert_to_tuple(value: int | tuple[int, int] | list) -> tuple[int, int]: def extract_connection_info( self, name_mappings: dict[BaseModel, str], - data_to_key_map: dict[Scalar | Tensor, list[str]] | None = None, - data_memo: dict | None = None, + data_to_key_map: dict[Scalar | Tensor[Any], list[str]] | None = None, + data_memo: Mapping[int, Tensor[Any] | Scalar] | None = None, ): if data_to_key_map is None: data_to_key_map = {} @@ -197,11 +202,11 @@ def summary( } if shapes: # extract model shapes - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) if types: # extract model types - type_info = _get_summary_types(name_mappings) + type_info = get_summary_types(name_mappings) if not name: name = self.__class__.__name__ diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index fb736b91..92315c25 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -32,11 +32,11 @@ Updates, Variadic, _get_shapes, - _get_summary_shapes, - _get_summary_types, _ShapesType, create_shape_map, get_summary, + get_summary_shapes, + get_summary_types, ) from ..logical.base import BaseModel from ..logical.model import Model @@ -1014,7 +1014,7 @@ def summary( name_mappings = define_unique_names(all_models) conn_info = self.extract_connection_info(name_mappings) - model_shapes = { + model_shapes: dict[str, _ShapesType] = { sub_model_name: self.get_shapes( sub_model, uni_keys, var_keys, symbolic, alternative_shapes ) @@ -1033,11 +1033,11 @@ def summary( if verbose: if shapes: # extract the shape info if necessary - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) if types: # extract the type info if necessary - type_info = _get_summary_types(name_mappings, self.data_store.data_memo) + type_info = get_summary_types(name_mappings, self.data_store.data_memo) # if verbose, find the name of the model and create the table object and # display it based on extracted infos diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 4dcd2430..63fcd3bd 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -336,7 +336,11 @@ def find_intersection_type( continue args_2 = typ_2.__args__ - assert typ_2.__origin__ is tuple or typ_2.__origin__ is list + assert ( + typ_2.__origin__ is tuple + or typ_2.__origin__ is list + or typ_2.__origin__ is dict + ) if typ_1.__origin__ == typ_2.__origin__: if len(args_1) == 0 or len(args_2) == 0: # if one of the lengths of the args_1 and args_2 are zero, diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index ff9b3960..7442b6a2 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -27,7 +27,7 @@ KeyType, Model, _get_shapes, - _get_summary_shapes, + get_summary_shapes, ) from ..framework.common import TBD, NotAvailable, Table from ..framework.logical import ( @@ -546,7 +546,7 @@ def summary( ), ) - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) if self.loss_keys: # If any loss is attached, extract useful information # about each added loss and print the table diff --git a/mithril/utils/utils.py b/mithril/utils/utils.py index d44a1a1a..fe3fe0a9 100755 --- a/mithril/utils/utils.py +++ b/mithril/utils/utils.py @@ -535,7 +535,9 @@ def convert_to_list( return value -def find_dominant_type(lst, raise_error: bool = True): +def find_dominant_type( + lst: Any, raise_error: bool = True +) -> type[int] | type[float] | type[bool]: # return dominant type of parameters in the list. # dominant type is referenced from numpy and in folloing order: bool -> int -> float # if any of the parameters are different from these three types, returns ValueError diff --git a/tests/scripts/test_constraints.py b/tests/scripts/test_constraints.py index 3c81fc72..ff45d5ed 100644 --- a/tests/scripts/test_constraints.py +++ b/tests/scripts/test_constraints.py @@ -157,7 +157,7 @@ def variadic_update_values( def extract_uniadic_possibles( uni: Uniadic, assignments: AssignmentType, - uni_cache: dict[UniadicRecord | Variadic, str], + uni_cache: dict[UniadicRecord, str], ) -> None: # Takes an uniadic object and fills the assignments dictionary # based on possible values of the uniadic object. @@ -170,8 +170,8 @@ def extract_uniadic_possibles( def extract_variadic_possibles( var: Variadic, assignments: AssignmentType, - uni_cache: dict[UniadicRecord | Variadic, str], - var_cache: dict[UniadicRecord | Variadic, str], + uni_cache: dict[UniadicRecord, str], + var_cache: dict[Variadic, str], ) -> None: assert var.possibles is not None all_possible_values: dict[int, PossibleValues] = var.possibles @@ -203,8 +203,8 @@ def assert_shape_results( data[key] for key in expected_updates } == updated_symbols.shape_updates | updated_symbols.value_updates # Then check final shapes with the expected ref_results. - uni_cache: dict[UniadicRecord | Variadic, str] = {} - var_cache: dict[UniadicRecord | Variadic, str] = {} + uni_cache: dict[UniadicRecord, str] = {} + var_cache: dict[Variadic, str] = {} shapes = {} assignments: AssignmentType = {} for key, value in data.items(): diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index fedc1822..92665a49 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -4299,16 +4299,6 @@ def test_connect_error_3(): assert str(error_info.value) == "Input keys are always exposed!" -def test_connect_error_4(): - model = Model() - model += Relu()(input="input2", output=IOKey(name="output")) - - with pytest.raises(KeyError) as error_info: - model += Relu()(input="input", output=Connect("input2", 3)) # type: ignore - - assert str(error_info.value) == "'Requires Connection object or string!'" - - def test_connect_error_5(): model_2 = Model() model_2 += Tanh()(input="input1", output=IOKey(name="output1")) @@ -6628,8 +6618,8 @@ def test_cyclic_extend(): def assert_repr_dict(data: dict[str, ShapeRepr], ref_shapes: dict): - uni_cache: dict[UniadicRecord | Variadic, str] = {} - var_cache: dict[UniadicRecord | Variadic, str] = {} + uni_cache: dict[UniadicRecord, str] = {} + var_cache: dict[Variadic, str] = {} shapes = { key: value.get_shapes(uni_cache, var_cache) for key, value in data.items() } diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py index 6bb294bd..b6be55ce 100644 --- a/tests/scripts/test_shapes.py +++ b/tests/scripts/test_shapes.py @@ -193,8 +193,8 @@ def repr_sort(repr): def get_deterministic_shape(node: ShapeNode): - uni: dict[UniadicRecord | Variadic, str] = {} - var: dict[UniadicRecord | Variadic, str] = {} + uni: dict[UniadicRecord, str] = {} + var: dict[Variadic, str] = {} if len(reprs := node.reprs) != 1: sorted_reprs = sorted(reprs, key=repr_sort, reverse=True) return [repr.get_shapes(uni, var) for repr in sorted_reprs] @@ -239,8 +239,8 @@ def assert_all_reprs_unique(model: BaseModel): """ all_reprs = get_all_reprs(model) - uni_cache: dict[UniadicRecord | Variadic, str] = {} - var_cache: dict[UniadicRecord | Variadic, str] = {} + uni_cache: dict[UniadicRecord, str] = {} + var_cache: dict[Variadic, str] = {} for repr1, repr2 in combinations(all_reprs, 2): repr1_shapes = repr1.get_shapes(uni_cache, var_cache) @@ -271,8 +271,8 @@ def assert_all_integer_uniadics_unique(model: BaseModel): def assert_match_shapes( repr1: ShapeRepr, repr2: ShapeRepr, repr1_ref_shapes: list, repr2_ref_shapes: list ): - uni_cache: dict[UniadicRecord | Variadic, str] = {} - var_cache: dict[UniadicRecord | Variadic, str] = {} + uni_cache: dict[UniadicRecord, str] = {} + var_cache: dict[Variadic, str] = {} repr1._match(repr2) @@ -9878,8 +9878,8 @@ def test_bcast_uniadics(): output = ShapeRepr(root=Variadic(), suffix=[Uniadic(1), Uniadic(2)]) bcast_uniadics(output, left, right, 0) - uni_cache: dict[UniadicRecord | Variadic, str] = {} - var_cache: dict[UniadicRecord | Variadic, str] = {} + uni_cache: dict[UniadicRecord, str] = {} + var_cache: dict[Variadic, str] = {} assert left.get_shapes(uni_cache, var_cache) == ["(V1, ...)", 1, "u1"] assert right.get_shapes(uni_cache, var_cache) == ["(V2, ...)"] assert output.get_shapes(uni_cache, var_cache) == ["(V3, ...)", 1, 2] @@ -9893,8 +9893,8 @@ def test_bcast_align(): output = ShapeRepr(root=Variadic(), suffix=[Uniadic(1), Uniadic(2)]) bacast_align_output(output, left, right, 0) - uni_cache: dict[UniadicRecord | Variadic, str] = {} - var_cache: dict[UniadicRecord | Variadic, str] = {} + uni_cache: dict[UniadicRecord, str] = {} + var_cache: dict[Variadic, str] = {} assert left.get_shapes(uni_cache, var_cache) == ["(V1, ...)", "u1", "u2"] assert right.get_shapes(uni_cache, var_cache) == ["(V2, ...)"] assert output.get_shapes(uni_cache, var_cache) == ["(V3, ...)", 1, 2] diff --git a/tests/scripts/test_summary.py b/tests/scripts/test_summary.py index f3a8afc5..2c0880d1 100644 --- a/tests/scripts/test_summary.py +++ b/tests/scripts/test_summary.py @@ -29,7 +29,7 @@ Table, UniadicRecord, Variadic, - _get_summary_shapes, + get_summary_shapes, ) from mithril.framework.utils import define_unique_names from mithril.models import ( @@ -547,7 +547,7 @@ def test_extract_shapes_logical_1(): sub_model_name: sub_model.get_shapes(uni_cache, var_cache, False, False) for sub_model, sub_model_name in name_mappings.items() } - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) assert shape_info == { "Buffer_0": ({"input": [37, 23]}, {"output": [37, 23]}), "Buffer_1": ({"input": [37, 23]}, {"output": [37, 23]}), @@ -569,7 +569,7 @@ def test_extract_shapes_logical_2(): sub_model_name: sub_model.get_shapes(uni_cache, var_cache, False, False) for sub_model, sub_model_name in name_mappings.items() } - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) assert shape_info == { "Buffer_0": ({"input": [45, 96, 2]}, {"output": [45, 96, 2]}), "Buffer_1": ({"input": [45, 96, 2]}, {"output": [45, 96, 2]}), @@ -600,7 +600,7 @@ def test_extract_shapes_logical_3(): sub_model_name: sub_model.get_shapes(uni_cache, var_cache, symbolic=True) for sub_model, sub_model_name in name_mappings.items() } - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) assert shape_info == { "Linear_0": ( {"input": [4, "u1"], "w": ["u1", 4], "b": [4]}, @@ -637,7 +637,7 @@ def test_extract_shapes_logical_4(): sub_model_name: sub_model.get_shapes(uni_cache, var_cache, symbolic=False) for sub_model, sub_model_name in name_mappings.items() } - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) assert shape_info == { "Convolution2D_0": ( { @@ -702,7 +702,7 @@ def test_extract_shapes_logical_5(): sub_model_name: sub_model.get_shapes(uni_cache, var_cache, symbolic=True) for sub_model, sub_model_name in name_mappings.items() } - shape_info = _get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) assert shape_info == { "Linear_0": ( {"input": ["u1", "u2"], "w": ["u2", 4], "b": [4]}, From 913eeea30c325d19ab55a654ff1671bb41efc14d Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 25 Nov 2024 14:51:08 +0300 Subject: [PATCH 05/26] style: eye constraint bug is fixed, logical part of frameworks is finished --- benchmarks/speed_benchmarks/benchmark.py | 2 +- mithril/framework/common.py | 16 ++++- mithril/framework/constraints.py | 2 +- mithril/framework/logical/model.py | 80 +++++++++++++----------- mithril/framework/logical/primitive.py | 15 +++-- mithril/framework/physical/model.py | 4 +- mithril/models/train_model.py | 6 +- tests/scripts/test_summary.py | 10 +-- 8 files changed, 79 insertions(+), 56 deletions(-) diff --git a/benchmarks/speed_benchmarks/benchmark.py b/benchmarks/speed_benchmarks/benchmark.py index fdee645f..5c3c869c 100644 --- a/benchmarks/speed_benchmarks/benchmark.py +++ b/benchmarks/speed_benchmarks/benchmark.py @@ -156,5 +156,5 @@ ] ) -table._compile() +table.compile() table.display() diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 64a81019..b8bf4352 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -2844,7 +2844,7 @@ def _adjust_table(self): ) ) - def _compile( + def compile( self, row_sep: str | list[str] = " | ", col_sep: str = "-", @@ -3407,7 +3407,15 @@ def get_summary_shapes( return shape_info -def get_summary_types(name_mappings: dict, data_memo=None): +# TODO: Name mappings in there should more specialized as Any is not a good choice +# name mappings should be dict[BaseModel, str] +# data_memo should be dict[int, Tensor[Any] | Scalar] +# however, if this happens, there will be circular import problem +# So carrying this function to another module may be a better idea +# (maybe this function could be a method of BaseModel?) +def get_summary_types( + name_mappings: dict[Any, Any], data_memo: dict[Any, Any] | None = None +): if data_memo is None: data_memo = {} @@ -3433,7 +3441,9 @@ def get_summary_types(name_mappings: dict, data_memo=None): return type_info -def is_type_adjustment_required(data: dict[str, Tensor | Scalar], inputs: list[str]): +def is_type_adjustment_required( + data: dict[str, Tensor[Any] | Scalar], inputs: list[str] +): if len(inputs) <= 2: return False inputs = inputs[:2] diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index bb6f7193..4e5ed2cd 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -2883,7 +2883,7 @@ def eye_constraints(output: Tensor, N: Scalar, M: Scalar) -> ConstrainResultType updates.add(N) if m_valued and not m_uni_valued: - assert isinstance(M.value, int) + assert isinstance(M.value, int | NoneType) m_uni.set_value(M.value) updates.add(m_uni) elif m_uni_valued and not m_valued: diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 51338af5..a7d1ae18 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -41,6 +41,7 @@ ShapeTemplateType, Tensor, ToBeDetermined, + UniadicRecord, Updates, Variadic, get_summary, @@ -135,7 +136,9 @@ } -coercion_table: dict[tuple[str, type[Tensor] | type[Scalar]], type[PrimitiveModel]] = { +coercion_table: dict[ + tuple[str, type[Tensor[Any]] | type[Scalar]], type[PrimitiveModel] +] = { ("item", Tensor): TensorItem, ("item", Scalar): ScalarItem, ("slice", Tensor): TensorSlice, @@ -143,7 +146,7 @@ } type_conversion_map: dict[ - tuple[type[Tensor] | type[Scalar], type[Tensor] | type[Scalar]], + tuple[type[Tensor[Any]] | type[Scalar], type[Tensor[Any]] | type[Scalar]], type[ToTensor] | type[TensorToList] | None, ] = { (Scalar, Tensor): ToTensor, @@ -222,7 +225,7 @@ def set_outputs(self, *args: str | Connection, **kwargs: str | Connection) -> No else: # Named connections. # Create new output connection with given key name. - data: Tensor | Scalar = ( + data: Tensor[Any] | Scalar = ( Scalar(metadata.data._type) if isinstance(metadata.data, Scalar) else Tensor(metadata.data.shape, metadata.data._type) @@ -307,7 +310,7 @@ def _check_multi_write( # Note that Tensor type connections can not have any value # in logical models. if isinstance((data := local_connection.metadata.data), Scalar): - pair = [connection.metadata.data, data] + pair: list[Tensor[Any] | Scalar] = [connection.metadata.data, data] check_data, other = pair[local_input], pair[not local_input] if check_data.value is not TBD and check_data.value != other.value: raise ValueError("Multi-write detected for a valued input connection!") @@ -317,7 +320,7 @@ def _add_connection( model: BaseModel, local_key: str, given_connection: ConnectionType, - expose=None, + expose: bool | None = None, ) -> tuple[ConnectionData, Updates]: updates = Updates() outer_key, con_obj = None, None @@ -425,7 +428,7 @@ def _add_connection( return con_obj, updates def _unroll_template( - self, template: ExtendTemplate, joint_type: type[Tensor] | type[Scalar] + self, template: ExtendTemplate, joint_type: type[Tensor[Any]] | type[Scalar] ) -> ConnectionData: if template.output_connection is None: # Initialize all default init arguments of model as "..." other @@ -588,7 +591,7 @@ def merge_connections( @overload def handle_auto_conversion( self, - key_type: type[Tensor] | type[Scalar], + key_type: type[Tensor[Any]] | type[Scalar], is_input: bool, connection: IOKey | Connect | ConnectionData, updates: Updates, @@ -596,7 +599,7 @@ def handle_auto_conversion( @overload def handle_auto_conversion( # type: ignore[overload-cannot-match] # mypy import bug self, - key_type: type[Tensor] | type[Scalar], + key_type: type[Tensor[Any]] | type[Scalar], is_input: bool, connection: ConnectionType | tuple[ConnectionType, ...] | list[ConnectionType], updates: Updates, @@ -604,12 +607,13 @@ def handle_auto_conversion( # type: ignore[overload-cannot-match] # mypy import def handle_auto_conversion( self, - key_type: type[Tensor] | type[Scalar], + key_type: type[Tensor[Any]] | type[Scalar], is_input: bool, connection: ConnectionType | tuple[ConnectionType, ...] | list[ConnectionType], updates: Updates, ) -> ConnectionType: - connection_type: type[Tensor] | type[Scalar] | None = None + connection_type: type[Tensor[Any]] | type[Scalar] | None = None + data: Tensor[Any] | Scalar | None = None if isinstance(connection, MainValueInstance): connection_type = Scalar @@ -771,7 +775,7 @@ def handle_auto_conversion( else: # All connections in Connect object must have same type. - types: set[type[Scalar] | type[Tensor]] = { + types: set[type[Scalar] | type[Tensor[Any]]] = { _conn.metadata.data.__class__ for _conn in connections } if len(types) > 1: @@ -875,8 +879,8 @@ def handle_auto_conversion( @overload def create_connection_model( self, - connection_type: type[Tensor] | type[Scalar] | None, - key_type: type[Tensor] | type[Scalar], + connection_type: type[Tensor[Any]] | type[Scalar] | None, + key_type: type[Tensor[Any]] | type[Scalar], is_input: bool, connection: tuple[ConnectionType, ...] | list[ConnectionType], ) -> ConnectionData: ... @@ -886,8 +890,8 @@ def create_connection_model( @overload def create_connection_model( self, - connection_type: type[Tensor], - key_type: type[Tensor], + connection_type: type[Tensor[Any]], + key_type: type[Tensor[Any]], is_input: bool, connection: T, ) -> T: @@ -899,7 +903,7 @@ def create_connection_model( @overload def create_connection_model( # type: ignore[overload-cannot-match] # mypy import bug self, - connection_type: type[Tensor], + connection_type: type[Tensor[Any]], key_type: type[Scalar], is_input: bool, connection: ConnectionType | tuple[ConnectionType, ...] | list[ConnectionType], @@ -911,7 +915,7 @@ def create_connection_model( # type: ignore[overload-cannot-match] # mypy impor def create_connection_model( # type: ignore[overload-cannot-match] # mypy import bug self, connection_type: type[Scalar], - key_type: type[Tensor], + key_type: type[Tensor[Any]], is_input: bool, connection: ConnectionType | tuple[ConnectionType, ...] | list[ConnectionType], ) -> ConnectionData: @@ -934,15 +938,15 @@ def create_connection_model( # type: ignore[overload-cannot-match] # mypy impor def create_connection_model( self, connection_type: None, - key_type: type[Tensor] | type[Scalar], + key_type: type[Tensor[Any]] | type[Scalar], is_input: bool, connection: T, ) -> T: ... def create_connection_model( self, - connection_type: type[Tensor] | type[Scalar] | None, - key_type: type[Tensor] | type[Scalar], + connection_type: type[Tensor[Any]] | type[Scalar] | None, + key_type: type[Tensor[Any]] | type[Scalar], is_input: bool, connection: ConnectionType | tuple[ConnectionType, ...] | list[ConnectionType], ) -> ( @@ -1270,9 +1274,6 @@ def __add__(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: else (info, {}) ) - if not isinstance(model, BaseModel | PrimitiveModel): - raise TypeError("Added element should be a Model type.") - if ( model._canonical_input is not NOT_AVAILABLE and ( @@ -1306,7 +1307,12 @@ def __add__(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: @staticmethod def _update_key_name( - new_key, underscored_keys, raw_keys, key_mappings, key_origin, input_set + new_key: str, + underscored_keys: set[str], + raw_keys: dict[str, list[str]], + key_mappings: dict[str, str], + key_origin: str, + input_set: set[str], ) -> tuple[str, str]: # Add underscore if generated key name exists in input keys key_prefix = "_" @@ -1331,7 +1337,10 @@ def _update_key_name( return new_key, key_origin def _generate_keys( - self, symbolic=True, include_internals=True, include_outputs=False + self, + symbolic: bool = True, + include_internals: bool = True, + include_outputs: bool = False, ) -> dict[str, str]: key_mappings: dict[str, str] = {} raw_keys: dict[str, list[str]] = {} @@ -1392,7 +1401,7 @@ def _generate_keys( raw_keys, key_mappings, key_origin, - self._input_keys, + set(self._input_keys), ) new_key = key_origin + key_suffix @@ -1403,7 +1412,7 @@ def _generate_keys( raw_keys, key_mappings, key_origin, - self._input_keys, + set(self._input_keys), ) raw_keys[key_origin].append(key) key_mappings[key] = new_key @@ -1485,8 +1494,8 @@ def summary( symbolic: bool = False, name: str | None = None, alternative_shapes: bool = False, - uni_cache: dict | None = None, - var_cache: dict | None = None, + uni_cache: dict[UniadicRecord, str] | None = None, + var_cache: dict[Variadic, str] | None = None, depth: int = 0, ) -> None: if uni_cache is None: @@ -1494,11 +1503,12 @@ def summary( if var_cache is None: var_cache = {} - type_info = None + type_info: dict[str, tuple[dict[str, str], dict[str, str]]] | None = None shape_info = None # extract relevant information about summary dag = self.dag - name_mappings = define_unique_names(dag) + # TODO: Fix typing issues + name_mappings: dict[BaseModel, str] = define_unique_names(dag) # extract model topology conn_info = self.extract_connection_info(name_mappings) @@ -1525,12 +1535,12 @@ def summary( conns=conn_info, name=name, shape=shape_info, types=type_info ) - table._compile() + table.compile() table.display() if depth > 0: for model, model_name in name_mappings.items(): - kwargs = { + kwargs: dict[str, Any] = { "depth": depth - 1, "shapes": shapes, "symbolic": symbolic, @@ -1547,10 +1557,10 @@ def summary( def extract_connection_info( self, name_mappings: dict[BaseModel, str], - data_to_key_map: dict[Tensor | Scalar, list[str]] | None = None, + data_to_key_map: dict[Tensor[Any] | Scalar, list[str]] | None = None, data_memo: Mapping[int, Tensor[Any] | Scalar] | None = None, ): - conn_info: dict[str, tuple[dict, dict]] = {} + conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {} if self._input_keys: if data_to_key_map is None: data_to_key_map = {} diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index d3cd083a..16000809 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -26,7 +26,9 @@ Scalar, Tensor, TensorType, + UniadicRecord, Updates, + Variadic, create_shape_map, get_summary, get_summary_shapes, @@ -59,7 +61,7 @@ def __init__( if isinstance(value, TensorType) } shapes = create_shape_map(shape_templates, self.constraint_solver) - data_set = set() + data_set: set[Tensor[Any]] = set() is_diff = False for key, value in kwargs.items(): if isinstance(value, TensorType): @@ -74,7 +76,6 @@ def __init__( else: self.conns.set_connection_type(conn_data, KeyType.INPUT) is_diff |= not value.is_non_diff - if isinstance(output_data, Tensor): output_data._differentiable = is_diff @@ -129,11 +130,13 @@ def __iadd__(self, other: BaseModel): ) @staticmethod - def convert_to_tuple(value: int | tuple[int, int] | list) -> tuple[int, int]: + def convert_to_tuple(value: int | tuple[int, int] | list[Any]) -> tuple[int, int]: if isinstance(value, int): new_value = (value, value) elif isinstance(value, list): new_value = tuple(value) + else: + new_value = value return new_value def extract_connection_info( @@ -178,8 +181,8 @@ def summary( symbolic: bool = False, name: str | None = None, alternative_shapes: bool = False, - uni_cache: dict | None = None, - var_cache: dict | None = None, + uni_cache: dict[UniadicRecord, str] | None = None, + var_cache: dict[Variadic, str] | None = None, ) -> None: if uni_cache is None: uni_cache = {} @@ -216,7 +219,7 @@ def summary( conns=conn_info, name=name, shape=shape_info, types=type_info ) - table._compile() + table.compile() table.display() def _freeze(self) -> None: diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 92315c25..16928078 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -957,7 +957,7 @@ def _print_model_info( pm_info, right_length=1, left_length=18, len_space=1, r_len=100 )[:-1] info_table.add_row([info]) - info_table._compile() + info_table.compile() info_table.display() def summary( @@ -1051,7 +1051,7 @@ def summary( params=param_info, ) - table._compile() + table.compile() table.display() if depth > 0: for model, model_name in name_mappings.items(): diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index 7442b6a2..d8de72db 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -586,7 +586,7 @@ def summary( t_list.append([reduce_str[:-2]]) t_list.append([str(loss_dict["coef"])]) loss_table.add_row(t_list) - loss_table._compile(row_sep=[" | ", " | ", " | ", " | ", " | "]) + loss_table.compile(row_sep=[" | ", " | ", " | ", " | ", " | "]) loss_table.display() if self.geomean_map: @@ -619,7 +619,7 @@ def summary( r_list.append(str(shape[updated_reg_key])) r_list.append([str(coef)]) reg_table.add_row(r_list) - reg_table._compile(row_sep=[" | ", " | ", " | "]) + reg_table.compile(row_sep=[" | ", " | ", " | "]) reg_table.display() if self.metric_keys: @@ -647,7 +647,7 @@ def summary( m_list.append([val[0] for val in conns.values()]) m_list.append([val[0] for val in out_conn.values()]) metric_table.add_row(m_list) - metric_table._compile(row_sep=[" | ", " | ", " | ", " | "]) + metric_table.compile(row_sep=[" | ", " | ", " | ", " | "]) metric_table.display() def _add_geo_mean(self): diff --git a/tests/scripts/test_summary.py b/tests/scripts/test_summary.py index 2c0880d1..f0e7b125 100644 --- a/tests/scripts/test_summary.py +++ b/tests/scripts/test_summary.py @@ -786,7 +786,7 @@ def test_table_1(): table.add_header(headers) for list in list_1: table.add_row(list) - table._compile(row_sep=" ") + table.compile(row_sep=" ") cells = table.cell_str n_of_rows = cells.count("\n") - 2 @@ -802,7 +802,7 @@ def test_table_2(): table.add_header(headers) for list in list_1: table.add_row(list) - table._compile(row_sep=" ") + table.compile(row_sep=" ") cells = table.cell_str n_of_rows = cells.count("\n") - 2 assert n_of_rows == 4 @@ -817,7 +817,7 @@ def test_table_3(): table.add_header(headers) for list in list_1: table.add_row(list) - table._compile(row_sep=" ") + table.compile(row_sep=" ") cells = table.cell_str n_of_rows = cells.count("\n") - 2 assert n_of_rows == 4 @@ -834,7 +834,7 @@ def test_table_4(): table.add_header(subheaders) for list in list_1: table.add_row(list) - table._compile(row_sep=" | ") + table.compile(row_sep=" | ") cells = table.cell_str n_of_rows = cells.count("\n") - 2 assert n_of_rows == 4 @@ -851,7 +851,7 @@ def test_table_5(): table.add_header(subheaders) for list in list_1: table.add_row(list) - table._compile(row_sep=" | ") + table.compile(row_sep=" | ") cells = table.cell_str n_of_rows = cells.count("\n") - 2 assert n_of_rows == 12 From f88a7bb4fe28e3d4f0ca100f1685e4504ba80659 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Wed, 27 Nov 2024 13:11:42 +0300 Subject: [PATCH 06/26] fix: pypi errors are fixed in logical, physical, common, constraints, models, and utils files / directories --- mithril/__init__.py | 2 +- mithril/backends/backend.py | 6 +- .../with_autograd/jax_backend/backend.py | 2 +- .../with_autograd/mlx_backend/backend.py | 2 +- .../with_autograd/torch_backend/backend.py | 2 +- .../with_manualgrad/c_backend/backend.py | 2 +- .../with_manualgrad/numpy_backend/backend.py | 2 +- mithril/framework/codegen/__init__.py | 4 +- mithril/framework/codegen/c_ast.py | 2 +- mithril/framework/codegen/c_gen.py | 3 +- mithril/framework/codegen/code_gen.py | 3 +- mithril/framework/common.py | 106 ++++++++++-------- mithril/framework/constraints.py | 94 ++++++++-------- mithril/framework/logical/primitive.py | 1 + mithril/framework/physical/data_store.py | 22 ++-- mithril/framework/physical/flat_graph.py | 37 +++--- mithril/framework/physical/model.py | 81 ++++++------- mithril/framework/utils.py | 40 +++---- mithril/models/models.py | 24 ++-- mithril/models/primitives.py | 20 ++-- mithril/models/train_model.py | 46 ++++---- mithril/utils/func_utils.py | 7 +- mithril/utils/type_utils.py | 22 ++-- .../test_compile_keys_consistencies.py | 22 ++-- tests/scripts/test_type_consistencies.py | 2 +- 25 files changed, 292 insertions(+), 262 deletions(-) diff --git a/mithril/__init__.py b/mithril/__init__.py index 87408f76..269b3331 100644 --- a/mithril/__init__.py +++ b/mithril/__init__.py @@ -98,7 +98,7 @@ def compile( model: BaseModel, backend: Backend[DataType], *, - constant_keys: PhysicalConstantType | None = None, + constant_keys: PhysicalConstantType[DataType] | None = None, data_keys: Iterable[str | Connection] | None = None, discard_keys: Iterable[str | Connection] | None = None, jacobian_keys: Iterable[str | Connection] | None = None, diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py index 1161a73b..b722d054 100644 --- a/mithril/backends/backend.py +++ b/mithril/backends/backend.py @@ -38,8 +38,8 @@ class Backend(ABC, Generic[DataType]): is_installed = True _device: str _precision: int - primitive_function_dict: dict[str, Callable] - registered_primitives: dict[str, Callable] + primitive_function_dict: dict[str, Callable[..., DataType]] + registered_primitives: dict[str, Callable[..., DataType]] array_creation_funcs: list[str] primitive_fn_path: str @@ -77,7 +77,7 @@ def e(self): return math.e @property - def is_manualgrad(self): + def is_manualgrad(self) -> bool: raise NotImplementedError("is_manualgrad is not implemented") def get_backend_array_type(self): # noqa: B902 diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py index f04ae2e2..6a29ee7d 100644 --- a/mithril/backends/with_autograd/jax_backend/backend.py +++ b/mithril/backends/with_autograd/jax_backend/backend.py @@ -72,7 +72,7 @@ def __init__( self.prng_key = jax.random.PRNGKey(self.seed) @property - def is_manualgrad(self): + def is_manualgrad(self) -> bool: return False @property diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index 91e786fc..3689da7c 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -49,7 +49,7 @@ def __init__( mx.random.seed(self.seed) @property - def is_manualgrad(self): + def is_manualgrad(self) -> bool: return False @property diff --git a/mithril/backends/with_autograd/torch_backend/backend.py b/mithril/backends/with_autograd/torch_backend/backend.py index b350e64b..cd985be2 100644 --- a/mithril/backends/with_autograd/torch_backend/backend.py +++ b/mithril/backends/with_autograd/torch_backend/backend.py @@ -71,7 +71,7 @@ def __init__( torch.random.manual_seed(self.seed) @property - def is_manualgrad(self): + def is_manualgrad(self) -> bool: return False @property diff --git a/mithril/backends/with_manualgrad/c_backend/backend.py b/mithril/backends/with_manualgrad/c_backend/backend.py index fbc64706..c42758d5 100644 --- a/mithril/backends/with_manualgrad/c_backend/backend.py +++ b/mithril/backends/with_manualgrad/c_backend/backend.py @@ -32,7 +32,7 @@ def __init__(self): self._device = "cpu" @property - def is_manualgrad(self): + def is_manualgrad(self) -> bool: return True def set_seed(self, seed: int): diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index b4dcf4a2..45a5c27d 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -60,7 +60,7 @@ def __init__(self, device: str = "cpu", precision: int = 32) -> None: np.random.seed(self.seed) @property - def is_manualgrad(self): + def is_manualgrad(self) -> bool: return True @property diff --git a/mithril/framework/codegen/__init__.py b/mithril/framework/codegen/__init__.py index fbaa91b3..9f7c3fbf 100644 --- a/mithril/framework/codegen/__init__.py +++ b/mithril/framework/codegen/__init__.py @@ -13,11 +13,13 @@ # limitations under the License. +from typing import Any + from ...backends.backend import Backend from .code_gen import CodeGen from .python_gen import PythonCodeGen -code_gen_map: dict[type[Backend], type[CodeGen]] = {} +code_gen_map: dict[type[Backend[Any]], type[CodeGen]] = {} try: from ...backends.with_autograd.jax_backend import JaxBackend diff --git a/mithril/framework/codegen/c_ast.py b/mithril/framework/codegen/c_ast.py index 5b99a940..81da9aa4 100644 --- a/mithril/framework/codegen/c_ast.py +++ b/mithril/framework/codegen/c_ast.py @@ -20,7 +20,7 @@ @dataclass class AST(ABC): @abstractmethod - def to_str(self): + def to_str(self) -> str: raise NotImplementedError("to_str is not implemented") diff --git a/mithril/framework/codegen/c_gen.py b/mithril/framework/codegen/c_gen.py index b6b58b6e..643ebeb8 100644 --- a/mithril/framework/codegen/c_gen.py +++ b/mithril/framework/codegen/c_gen.py @@ -17,6 +17,7 @@ import subprocess import tempfile from functools import partial +from typing import Any from ...backends.with_manualgrad.c_backend import CBackend, backend from ...backends.with_manualgrad.c_backend.src import array @@ -33,7 +34,7 @@ class CGen(CodeGen): BACKWARD_FN_SUFFIX = "_grad" - def __init__(self, pm: PhysicalModel) -> None: + def __init__(self, pm: PhysicalModel[Any]) -> None: super().__init__(pm) assert isinstance(self.pm.backend, CBackend) diff --git a/mithril/framework/codegen/code_gen.py b/mithril/framework/codegen/code_gen.py index adecbb51..eccd03e7 100644 --- a/mithril/framework/codegen/code_gen.py +++ b/mithril/framework/codegen/code_gen.py @@ -14,12 +14,13 @@ from abc import ABC, abstractmethod from collections.abc import Callable +from typing import Any from ..physical.model import PhysicalModel class CodeGen(ABC): - def __init__(self, pm: PhysicalModel) -> None: + def __init__(self, pm: PhysicalModel[Any]) -> None: self.pm = pm self.code: str | None = None self.file_path: str | None = None diff --git a/mithril/framework/common.py b/mithril/framework/common.py index b8bf4352..1e50b16d 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -14,14 +14,14 @@ from __future__ import annotations -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Iterator, Mapping, Sequence from copy import copy, deepcopy from dataclasses import dataclass, field from enum import Enum from functools import partial, reduce from itertools import combinations, cycle, product, zip_longest from types import EllipsisType, GenericAlias, UnionType -from typing import Any, Literal, TypeVar +from typing import Any, Literal, TypeVar, overload from ..backends.backend import Backend from ..core import ( @@ -89,7 +89,7 @@ class MySingletonObject(SingletonObject): assert obj1 is obj2 # True, both are the same instance """ - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any): if cls._instance is None: cls._instance = super().__new__(cls, *args, **kwargs) return cls._instance @@ -155,11 +155,10 @@ class KeyType(Enum): | int | tuple[int, ...] | list[int] - | dict + | dict[Any, Any] | slice | Constant | tuple[int | None, ...] - | dict | str ) type DeferredValueType = ( @@ -344,7 +343,7 @@ def _reduce_uniadic_referees(self, updates: Updates): rec.referees = {uni} @staticmethod - def _find_intersection_reprs(repr1): + def _find_intersection_reprs(repr1: ShapeRepr) -> set[ShapeRepr]: intersection_reprs: set[ShapeRepr] = set() # Collect visited repr's symbols. symbols = repr1.prefix + repr1.suffix @@ -364,7 +363,11 @@ def _find_intersection_reprs(repr1): return intersection_reprs @staticmethod - def _add_sublists(repr1, intersection_reprs, deletion_nodes): + def _add_sublists( + repr1: ShapeRepr, + intersection_reprs: set[ShapeRepr], + deletion_nodes: dict[ShapeNode, set[ShapeNode]], + ) -> Updates: updates = Updates() for repr2 in intersection_reprs: if (repr1.node != repr2.node) and (repr1 in repr2): @@ -525,14 +528,14 @@ def _get_shapes( symbolic: bool = True, verbose: bool = False, key_mappings: dict[str, str] | None = None, -) -> _ShapesType: +) -> dict[str, ShapeTemplateType | list[ShapeTemplateType]]: if key_mappings is None: key_mappings = {} if uniadic_keys is None: uniadic_keys = {} if varadic_keys is None: varadic_keys = {} - shapes: dict = {} + shapes: dict[str, ShapeTemplateType | list[ShapeTemplateType] | None] = {} for key, data in data_dict.items(): key_name = key_mappings.get(key, key) if isinstance(data, Tensor): @@ -541,7 +544,7 @@ def _get_shapes( ) else: shapes[key_name] = None - return shapes + return shapes # type: ignore class BaseData: @@ -661,7 +664,7 @@ def match(self, other: Tensor[Any] | Scalar) -> Updates: class Tensor(BaseData, GenericDataType[DataType]): _type: type[float] | type[int] | type[bool] | UnionType temp_value: TensorValueType | ToBeDetermined - value: DataType | ToBeDetermined + value: DataType | ToBeDetermined | None def __init__( self, @@ -696,14 +699,14 @@ def is_valued(self) -> bool: def _set_as_physical(self): super()._set_as_physical() - def _convert_value(self, backend: Backend[Any]) -> DataType | ToBeDetermined: + def _convert_value(self, backend: Backend[DataType]) -> DataType | ToBeDetermined: if isinstance(self.temp_value, Constant): self.value = backend.array( epsilon_table[backend.precision][self.temp_value] ) elif self.temp_value is not TBD: self.value = backend.array(self.temp_value) - return self.value + return self.value # type: ignore def make_physical( self, backend: Backend[Any], memo: dict[int, Tensor[Any] | Scalar] @@ -1215,7 +1218,7 @@ def __hash__(self) -> int: | Mapping[str, ShapeTemplateType] | Mapping[Connection, ShapeTemplateType] ) -_ShapesType = Mapping[str, ShapeTemplateType] +_ShapesType = Mapping[str, ShapeTemplateType | list[ShapeTemplateType]] @dataclass @@ -1239,7 +1242,7 @@ def __hash__(self) -> int: def __and__(self, other) -> Connect: return Connect(self.conn) & other - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: return id(self) == id(other) def set_differentiable(self, differentiable: bool = True) -> None: @@ -1686,7 +1689,7 @@ def update_dnf(self) -> bool: def get_all_uniadics(self) -> set[Uniadic]: # TODO: add all ANDs Uniadics into lookup table, then remove this method and # directly call self.lookup_table.keys() - uniadics = set() + uniadics: set[Uniadic] = set() for dnf in self.dnf_list: for item in dnf.item_list: for key, value in item.uni_table.items(): @@ -1879,7 +1882,7 @@ def _match_possibles( updates = Updates() # Clip self.possibles with new_prefix and new_suffix # Add clipped uniadics to new equivalences. - possibles = [] + possibles: list[PossibleValues] = [] for _len, pos in self.possibles.items(): if _len < len(prefix) + len(suffix): continue @@ -2301,7 +2304,7 @@ def get_shapes( v_keys: dict[Variadic, str] | None = None, symbolic: bool = True, verbose: bool = False, - ) -> list[int | str | None] | list[list[int | str | None]]: + ) -> ShapeTemplateType | list[ShapeTemplateType]: if u_keys is None: u_keys = {} if v_keys is None: @@ -2344,14 +2347,14 @@ def get_most_informative_repr(self) -> ShapeRepr: ) # Count number of used reprs - n_reprs = set() + n_reprs: set[ShapeRepr] = set() for uni in repr.prefix + repr.suffix: n_reprs |= uni.reprs if repr.root is not None: n_reprs |= repr.root.reprs - best_n_reprs = set() + best_n_reprs: set[ShapeRepr] = set() for uni in most_informative_repr.prefix + most_informative_repr.suffix: best_n_reprs |= uni.reprs @@ -2471,18 +2474,16 @@ def __contains__(self, key: ShapeRepr) -> bool: ) or self._is_subset_rootless(self.suffix, key.prefix) return False - def __getitem__(self, position): + def __getitem__(self, position: int): # TODO: Currently position could only be int, but we should support slicing # operations too (e.g. repr[:2]) if it is possible (if index of Variadic # field allows the operation). - if not isinstance(position, int): - raise ValueError("Requires int index!") if position < 0 and self.root is not None: return self.suffix[position] else: return self.prefix[position] - def __setitem__(self, position, new_item): + def __setitem__(self, position: int, new_item: Uniadic): if position < 0 and self.root is not None: self.suffix[position] = new_item else: @@ -2518,7 +2519,7 @@ def get_shapes( u_keys: dict[UniadicRecord, str] | None = None, v_keys: dict[Variadic, str] | None = None, symbolic: bool = True, - ) -> list[int | str | None]: + ) -> ShapeTemplateType: if u_keys is None: u_keys = {} if v_keys is None: @@ -2699,12 +2700,12 @@ def clear(self): # all([uni.value is not None for uni in repr.prefix]) @dataclass class Constraint: - fn: Callable + fn: ConstraintFunctionType type: UpdateType = UpdateType.SHAPE call_counter: int = 0 - post_processes: set[Callable] = field(default_factory=lambda: set()) + post_processes: set[ConstraintFunctionType] = field(default_factory=lambda: set()) - def __call__(self, keys: list[Scalar | Tensor]) -> ConstrainResultType: + def __call__(self, keys: list[Scalar | Tensor[Any]]) -> ConstrainResultType: status = False updates = Updates() if self.type == UpdateType.SHAPE: @@ -2725,7 +2726,7 @@ def add_post_process(self, fn: ConstraintFunctionType): self.post_processes.add(fn) def create_post_constraints(self): - constraints = set() + constraints: set[Constraint] = set() for fn in self.post_processes: constraints.add(Constraint(fn, self.type)) return constraints @@ -2734,8 +2735,16 @@ def __hash__(self) -> int: return hash(id(self)) -def add_lens(x, y, const): - return x + next(const) + y +@overload +def add_lens(x: int, y: int, const: Iterator[int]) -> int: ... + + +@overload +def add_lens(x: str, y: str, const: Iterator[str]) -> str: ... + + +def add_lens(x: str | int, y: str | int, const: Iterator[str | int]) -> str | int: + return x + next(const) + y # type: ignore type RowColumnType = list[str | list[str]] | list[str] | list[list[str]] @@ -2760,6 +2769,7 @@ def add_row(self, row: RowColumnType): self.cells.append(row) def add_column(self, column: RowColumnType): + idx = 0 for idx, row in enumerate(column[: len(self.headers)]): self.headers[idx].append(row) # type: ignore @@ -2777,12 +2787,12 @@ def _calculate_table_specs(self): # Initialize cell_heights and cell_widths lists, these # lists will hold minimum height and width that every cell # can have respectively. - cell_heights = [] - cell_widths = [] + cell_heights: list[list[int]] = [] + cell_widths: list[list[int]] = [] for row in all_elems: - row_heights = [] - row_widths = [] + row_heights: list[int] = [] + row_widths: list[int] = [] for cell in row: if isinstance(cell, list): # if cell is a list, length of the list will give minimum height @@ -2874,16 +2884,18 @@ def compile( # adjust the table accordingly self._adjust_table() # calculate total table width - table_width = reduce( + table_width = reduce( # type: ignore partial(add_lens, const=(len(row) for row in row_sep)), self.each_row_width ) - table_constructor_fn = partial(add_lens, const=cycle(row_sep)) - table_constructor_fn_w_spaces = partial( + table_constructor_fn: Callable[[str, str], str] = partial( + add_lens, const=cycle(row_sep) + ) # type: ignore + table_constructor_fn_w_spaces: Callable[[str, str], str] = partial( # type: ignore add_lens, const=cycle(len(row) * " " for row in row_sep) ) end = "\n" - header_list = [] - cell_list = [] + header_list: list[str] = [] + cell_list: list[str] = [] # Construct the header if it exists header_list.append(self.name.center(table_width) + end) @@ -2968,7 +2980,7 @@ def construct_subtable_str( @staticmethod def fill_spaces( - value: list | str, + value: list[str] | str, max_width: int = 0, max_height: int = 0, align: Literal["left", "right", "center"] = "left", @@ -3034,7 +3046,7 @@ def adjust_list(strings: list[str], max_len: int = 0) -> list[str]: """ if not strings: # Fill empty strings with None - new_list = ["None"] + new_list: list[str] = ["None"] else: line_len = 0 new_str = "" @@ -3052,7 +3064,7 @@ def adjust_list(strings: list[str], max_len: int = 0) -> list[str]: @staticmethod def dict_to_table( - in_dict: Mapping[str, list | dict], + in_dict: Mapping[str, list[str] | dict[str, Any]], seperator: str = " : ", left_align: Literal["left", "right", "center"] = "left", right_align: Literal["left", "right", "center"] = "left", @@ -3060,7 +3072,7 @@ def dict_to_table( right_length: int = 0, len_space: int = 0, space_fill: str = "-", - r_len=0, + r_len: int = 0, ) -> list[str]: """takes a dicts and creates a list of strings from that dict by filling empty spaces and making necessary alignments This function will work recursivey of @@ -3102,7 +3114,7 @@ def dict_to_table( list[str]: list of strings with same length """ - table = [] + table: list[str] = [] left_list: list[str] = [] right_list: list[str] = [] @@ -3145,7 +3157,7 @@ def dict_to_table( def create_shape_map( - shape_template: _ShapesType, + shape_template: Mapping[str, ShapeTemplateType], solver: ConstraintSolver, ) -> dict[str, ShapeRepr]: used_keys: UsedKeysType = {} @@ -3419,7 +3431,7 @@ def get_summary_types( if data_memo is None: data_memo = {} - type_info: dict[str, tuple[dict, dict]] = {} + type_info: dict[str, tuple[dict[str, str], dict[str, str]]] = {} for model, model_name in name_mappings.items(): in_dict, out_dict = type_info.setdefault(model_name, ({}, {})) diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index 4e5ed2cd..fbb9042a 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -102,18 +102,14 @@ # Below functions are used in various constraints. -def prod_fn(a, b): - return (a if isinstance(a, int) else a.value) * ( +def prod_fn(a: int | Uniadic, b: int | Uniadic) -> int: + return (a if isinstance(a, int) else a.value) * ( # type: ignore b if isinstance(b, int) else b.value ) -def is_repr_known(repr) -> bool: - return ( - repr.root is None - and repr.prefix - and all([uni.value is not None for uni in repr.prefix]) - ) +def is_repr_known(repr: ShapeRepr) -> bool: + return repr.root is None and all([uni.value is not None for uni in repr.prefix]) def create_union_type( @@ -159,7 +155,7 @@ def general_tensor_type_constraint(*args: Scalar | Tensor[Any]): output, *inputs = args arg_types: set[type | UnionType | NestedListType | GenericAlias] = set() all_possible_types: set[type | UnionType | NestedListType | GenericAlias] = set() - union_types: set[tuple[Scalar | Tensor, UnionType]] = set() + union_types: set[tuple[Scalar | Tensor[Any], UnionType]] = set() # Set all different types and also Union types in input args. for arg in inputs: typ = arg._type @@ -1456,7 +1452,7 @@ def reduce_constraints( # Try to infer output shape structure from input shape structure. # First initialize out_prefix and out_suffix with the Uniadics # which may be transferred to the output. - out_prefix = [] + out_prefix: list[Uniadic] = [] for idx, uni in enumerate(input_shape.prefix): if idx not in axis_val: if not neg_idx or idx < (len(input_shape) - neg_idx): @@ -1466,7 +1462,7 @@ def reduce_constraints( elif replacement: out_prefix.append(replacement) - out_suffix = [] + out_suffix: list[Uniadic] = [] for idx, uni in enumerate(input_shape.suffix): if (idx - len(input_shape.suffix)) not in axis_val: if not positive_axes or ( @@ -1578,7 +1574,7 @@ def reduce_constraints( ) axis_val = tuple([idx if idx > 0 else idx + in_rank for idx in axis_val]) out_iter = iter(output_shape.prefix) - input_uniadics = [] + input_uniadics: list[Uniadic] = [] for idx in range(in_rank): if idx in axis_val: input_uniadics.append(Uniadic()) @@ -1628,7 +1624,7 @@ def concat_constraints( # is negative var = Variadic() if axis_val >= 0: - uniadics = [Uniadic() for _ in range(axis_val)] + uniadics: list[Uniadic] = [Uniadic() for _ in range(axis_val)] for repr in reprs: updates |= repr.inner_match(prefix=uniadics + [Uniadic()], root=var) elif axis_val < 0: @@ -1644,7 +1640,9 @@ def concat_constraints( # values at axis. shape formula of output of axis must be out = # sum(all ins). Therefore, if there is only one unknown, we can # infer unknown uniadic's shape by algebra. - uniadics, uniadic_values, pruned_uni_values = [], [], [] + uniadics = [] + uniadic_values: list[int | None] = [] + pruned_uni_values: list[int] = [] for repr in reprs: if ( repr.root is None @@ -1687,7 +1685,7 @@ def concat_constraints( else: dividing_factor = 1 substract_factor = 0 - none_values = [] + none_values: list[Uniadic] = [] for key in keys: if key.root is None: unis_without_value = [ @@ -1765,7 +1763,7 @@ def reverse_constraints( a_val: list[int] | tuple[int, ...] = ( [axes_val] if isinstance(axes_val, int) else axes_val ) - in_unis = [Uniadic() for idx in range(len(a_val))] + in_unis = [Uniadic() for _ in range(len(a_val))] out_unis = [in_unis[axis] for axis in a_val] updates |= input_shape._update_uniadics(input_shape.prefix, in_unis) @@ -1824,7 +1822,7 @@ def polynomial_features_constraints( elif ( input_uniadic.value is None and output_uniadic.value is not None - and degree is not None + and degree_val is not None ): # Increment input dimensionality by one up to # satisfying the equation: (dim + degree).(dim + degree - 1)....(dim + 1) = @@ -1939,12 +1937,12 @@ def sliding_window_1d_constraints( def conv_1d_constraints( - output: Tensor, - input: Tensor, + output: Tensor[Any], + input: Tensor[Any], stride: Scalar, padding: Scalar, dilation: Scalar, - kernel: Tensor, + kernel: Tensor[Any], ) -> ConstrainResultType: updates = Updates() status = False @@ -2002,8 +2000,8 @@ def conv_1d_constraints( # TODO: Change name (Conv also uses the constraint below) def sliding_window_2d_constraints( - output: Tensor, - input: Tensor, + output: Tensor[Any], + input: Tensor[Any], stride: Scalar, padding: Scalar, dilation: Scalar, @@ -2077,12 +2075,12 @@ def sliding_window_2d_constraints( def conv_2d_constraints( - output: Tensor, - input: Tensor, + output: Tensor[Any], + input: Tensor[Any], stride: Scalar, padding: Scalar, dilation: Scalar, - kernel: Tensor, + kernel: Tensor[Any], ) -> ConstrainResultType: status = False updates = Updates() @@ -2159,11 +2157,10 @@ def conv_2d_constraints( def flatten_constrains( - output: Tensor, input: Tensor, start_dim: Scalar, end_dim: Scalar + output: Tensor[Any], input: Tensor[Any], start_dim: Scalar, end_dim: Scalar ) -> ConstrainResultType: status = False updates = Updates() - new_shape_items = set() assert input._temp_shape is not None, "Input shape of Flatten is not set!" assert output._temp_shape is not None, "Output shape of Flatten is not set!" input_shape: ShapeRepr = input._temp_shape @@ -2296,14 +2293,13 @@ def flatten_constrains( suffix = input_shapes[end_dim_val + 1 :] if end_dim_val != -1 else [] prefix = input_shapes[:start_dim_val] updates |= output_shape.inner_match( - prefix=prefix + [(new_uni := Uniadic(prod))] + suffix + prefix=prefix + [Uniadic(prod)] + suffix ) - new_shape_items.add(new_uni) return status, updates def where_constrains( - output: Tensor, cond: Tensor, input1: Tensor, input2: Tensor + output: Tensor[Any], cond: Tensor[Any], input1: Tensor[Any], input2: Tensor[Any] ) -> ConstrainResultType: # TODO: Find a way to implement this constraint without creating a Tensor and # ShapeRepr @@ -2313,10 +2309,8 @@ def where_constrains( assert input2._temp_shape is not None, "Input2 shape of Where is not set!" status = False updates = Updates() - new_shape_items = set() - broadcast_shp = ShapeRepr(root=(new_var := Variadic())) - new_shape_items.add(new_var) + broadcast_shp = ShapeRepr(root=Variadic()) _, local_updates = bcast_helper( broadcast_shp, input1._temp_shape, input2._temp_shape, 0 @@ -2330,7 +2324,7 @@ def where_constrains( def arange_constraints( - output: Tensor, start: Scalar, stop: Scalar, step: Scalar + output: Tensor[Any], start: Scalar, stop: Scalar, step: Scalar ) -> ConstrainResultType: assert output._temp_shape is not None, "Output shape of Arange is not set!" output_shape: ShapeRepr = output._temp_shape @@ -2418,7 +2412,7 @@ def arange_constraints( def broadcast_to_constraints( - output: Tensor, shape: Scalar, input: Tensor + output: Tensor[Any], shape: Scalar, input: Tensor[Any] ) -> ConstrainResultType: status = False updates = Updates() @@ -2593,8 +2587,9 @@ def reshape_constraints( # Try to infer shape value. elif is_repr_known(output_shape): - if is_repr_known(input_shape) and reduce(prod_fn, input_shape.prefix) != reduce( - prod_fn, output_shape.prefix + if is_repr_known(input_shape) and reduce(prod_fn, input_shape.prefix) != reduce( # type: ignore + prod_fn, # type: ignore + output_shape.prefix, ): out_shape = tuple(uni.value for uni in output_shape.prefix) in_shape = tuple(uni.value for uni in input_shape.prefix) @@ -2620,7 +2615,7 @@ def reshape_constraints( return status, updates -def squeeze_constraints(output: Tensor, input: Tensor) -> ConstrainResultType: +def squeeze_constraints(output: Tensor[Any], input: Tensor[Any]) -> ConstrainResultType: updates = Updates() assert input._temp_shape is not None, "Input shape of Squeeze is not set!" assert output._temp_shape is not None, "Output shape of Squeeze is not set!" @@ -2652,7 +2647,8 @@ def squeeze_constraints(output: Tensor, input: Tensor) -> ConstrainResultType: # For example: input -> [4, Var, 2, u], output -> [4, 2], then # u = 1 ... - new_prefix, new_suffix = [], [] + new_prefix: list[Uniadic] = [] + new_suffix: list[Uniadic] = [] variadic_required = False for uni in input_shape.prefix: @@ -2664,7 +2660,7 @@ def squeeze_constraints(output: Tensor, input: Tensor) -> ConstrainResultType: # If Variadic input, iterate over reverse suffix else # reverse prefix. - reverse_uni_list = list() + reverse_uni_list: list[Uniadic] = list() for uni in ( input_shape.suffix[::-1] if input_shape.root is not None @@ -2789,14 +2785,14 @@ def size_constraints( max_pos_dim = max(pos_dims) + 1 if pos_dims else 0 max_neg_dim = -min(neg_dims) if neg_dims else 0 - input_prefix = [] + input_prefix: list[Uniadic] = [] for idx, _ in enumerate(range(max_pos_dim)): if len(input_shape.prefix) > idx: input_prefix.append(input_shape.prefix[idx]) else: input_prefix.append(Uniadic()) - input_suffix = [] + input_suffix: list[Uniadic] = [] rev_suffix = input_shape.suffix[::-1] for idx, _ in enumerate(range(max_neg_dim)): if len(rev_suffix) > idx: @@ -2864,7 +2860,7 @@ def shape_constraints(output: Scalar, input: Tensor[Any]) -> ConstrainResultType return status, updates -def eye_constraints(output: Tensor, N: Scalar, M: Scalar) -> ConstrainResultType: +def eye_constraints(output: Tensor[Any], N: Scalar, M: Scalar) -> ConstrainResultType: updates = Updates() assert output._temp_shape is not None, "Output shape of Eye is not set!" output_shape: ShapeRepr = output._temp_shape @@ -2894,7 +2890,7 @@ def eye_constraints(output: Tensor, N: Scalar, M: Scalar) -> ConstrainResultType def swap_axes_constraints( - output: Tensor, input: Tensor, axis1: Scalar, axis2: Scalar + output: Tensor[Any], input: Tensor[Any], axis1: Scalar, axis2: Scalar ) -> ConstrainResultType: assert input._temp_shape is not None, "Input shape of SwapAxes is not set!" assert output._temp_shape is not None, "Output shape of SwapAxes is not set!" @@ -3008,10 +3004,14 @@ def swap_axes_constraints( # If only one of the axes are given. Find the given axis. # create uniadics with the same amount of this axis and match it # with input + given_axis: int | None = None if not isinstance(axis1_val, ToBeDetermined): given_axis = axis1_val elif not isinstance(axis2_val, ToBeDetermined): given_axis = axis2_val + assert isinstance(given_axis, int) + + unis: list[Uniadic] = [] if given_axis >= 0: unis = [Uniadic() for _ in range(given_axis + 1)] elif given_axis < 0: @@ -3094,10 +3094,9 @@ def tensor_to_list_constraints( elif not isinstance(output_val, ToBeDetermined) and not isinstance( output_val, NestedListType ): + shape: list[Uniadic] = [] if isinstance(output_value, list | tuple): shape = [Uniadic(idx) for idx in list_shape(list(output_val))] - elif isinstance(output_value, float | int): - shape = [] updates |= input_shape.inner_match(prefix=shape) status = True @@ -3420,7 +3419,6 @@ def padding_2d_constraint( raise RuntimeError( "'same' padding is not supported when the kernel size is even!" ) - _padding = (kernel_size[0] // 2, kernel_size[1] // 2) elif isinstance(kernel_size, int): if kernel_size % 2 == 0: raise RuntimeError( @@ -3438,7 +3436,7 @@ def padding_2d_constraint( for p in input_value: if isinstance(p, int): updated_padding.append((p, p)) - elif isinstance(input_value, Sequence) and len(p) == 2: + elif len(p) == 2: updated_padding.append(tuple(p)) else: raise RuntimeError(f"Given padding '{input_value}' is not valid!") diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 16000809..c3102474 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -63,6 +63,7 @@ def __init__( shapes = create_shape_map(shape_templates, self.constraint_solver) data_set: set[Tensor[Any]] = set() is_diff = False + output_data: Tensor[Any] | Scalar | None = None for key, value in kwargs.items(): if isinstance(value, TensorType): value = value.construct(shapes[key].node) diff --git a/mithril/framework/physical/data_store.py b/mithril/framework/physical/data_store.py index f238e6ff..883e1541 100644 --- a/mithril/framework/physical/data_store.py +++ b/mithril/framework/physical/data_store.py @@ -38,23 +38,25 @@ class StaticDataStore(GenericDataType[DataType]): def __init__( self, - graph: FlatGraph, - backend: Backend, + graph: FlatGraph[DataType], + backend: Backend[DataType], inference: bool, solver: ConstraintSolver, - memo: dict | None = None, + memo: dict[int, Tensor[DataType] | Scalar] | None = None, ) -> None: if memo is None: memo = {} self.is_materialized = False - self._all_data: dict[str, Tensor | Scalar] = dict() - self.data_memo: dict[int, Tensor | Scalar] = dict() - self.graph = graph + self._all_data: dict[str, Tensor[DataType] | Scalar] = dict() + self.data_memo: dict[int, Tensor[DataType] | Scalar] = dict() + self.graph: FlatGraph[DataType] = graph self.backend: Backend[DataType] = backend self.inference = inference - self._cached_data: dict[str, Tensor | Scalar] = dict() - self._intermediate_non_differentiables: BiMap[str, Tensor | Scalar] = BiMap() + self._cached_data: dict[str, Tensor[DataType] | Scalar] = dict() + self._intermediate_non_differentiables: BiMap[ + str, Tensor[DataType] | Scalar + ] = BiMap() self._runtime_static_keys: set[str] = set() self._unused_keys: set[str] = set() # Final tensor values of data store. @@ -127,7 +129,7 @@ def _clear_constraints(self, key: str): def _update_cached_data(self, updated_data: Updates) -> set[str]: # If any data value is found by shape inference algorithms # transfer this data in cached_data. - transferred_keys = set() + transferred_keys: set[str] = set() updated_inter_data = ( updated_data.value_updates & self._intermediate_non_differentiables.inverse.keys() @@ -200,7 +202,7 @@ def set_shapes( for key in new_statics: self._infer_unused_keys(key) - def update_data(self, data: dict[str, Tensor | Scalar]): + def update_data(self, data: dict[str, Tensor[DataType] | Scalar]): if data.keys() & self._all_data.keys(): raise Exception("Some keys are already in data store!") self._all_data |= data diff --git a/mithril/framework/physical/flat_graph.py b/mithril/framework/physical/flat_graph.py index aa0ebf84..d37aa309 100644 --- a/mithril/framework/physical/flat_graph.py +++ b/mithril/framework/physical/flat_graph.py @@ -16,7 +16,6 @@ from collections.abc import Callable, Mapping from dataclasses import dataclass -from functools import partial from ...core import DataType from ..common import ( @@ -77,7 +76,7 @@ class Node: def __hash__(self) -> int: return hash(id(self)) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: return id(self) == id(other) def __repr__(self) -> str: @@ -128,7 +127,7 @@ def all_keys(self): | set(self.output_dict.values()) ) - def add_value(self, model: PrimitiveModel, keys: dict) -> bool: + def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> bool: output_key = keys[PrimitiveModel.output_key] keys = { key: self._temp_connection_info.get(value, value) @@ -249,7 +248,7 @@ def all_source_keys(self) -> set[str]: def _reorder_connections(self): queue = list(self._input_keys) - visited_keys = [] + visited_keys: list[str] = [] while queue: key = queue.pop() @@ -277,7 +276,7 @@ def _reorder_connections(self): for key in self._input_keys: visited_keys.remove(key) - nodes = {} + nodes: dict[PrimitiveModel, Node] = {} for key in visited_keys: model = self.get_model(key) if model is None: @@ -314,8 +313,8 @@ def _update_all_target_keys(self): } def _update_connection_keys(self, connection: Connection): - source_keys = [] - target_keys = [] + source_keys: list[str] = [] + target_keys: list[str] = [] if connection.node is not None: for inner_key, conn in connection.node.connections.items(): @@ -325,7 +324,7 @@ def _update_connection_keys(self, connection: Connection): source_keys.append(key) def get_target_keys(connection: Connection): - target_keys = [] + target_keys: list[str] = [] for conn in connection.connections: target_keys.append(conn.key) @@ -359,7 +358,7 @@ def get_model(self, key) -> PrimitiveModel: return conn.node.model - def get_model_out_key(self, model): + def get_model_out_key(self, model: PrimitiveModel): node = self.nodes.get(model, None) if node is None: return None @@ -374,8 +373,8 @@ def get_model_connections(self, model: PrimitiveModel): def get_connection(self, key: str): return self.connections.get(key, None) - def get_source_keys(self, key: str, include_outputs: bool = False): - source_keys = [] + def get_source_keys(self, key: str, include_outputs: bool = False) -> list[str]: + source_keys: list[str] = [] if key in self.connections: source_keys += self.connections[key].source_keys @@ -387,7 +386,7 @@ def get_source_keys(self, key: str, include_outputs: bool = False): return source_keys - def get_target_keys(self, key: str, include_outputs: bool = False): + def get_target_keys(self, key: str, include_outputs: bool = False) -> list[str]: target_keys = ( list(self.connections[key].target_keys) if key in self.connections else [] ) @@ -403,10 +402,10 @@ def get_target_keys(self, key: str, include_outputs: bool = False): def prune_duplicate_nodes( self, - data: dict[str, Tensor | Scalar], + data: dict[str, Tensor[DataType] | Scalar], constant_keys: Mapping[str, DataType | MainValueType], - ): - pruned_keys = {} + ) -> dict[str, str]: + pruned_keys: dict[str, str] = {} for node in list(self.nodes.values()): conn = self._is_duplicate(node, data, constant_keys) if conn is None: @@ -421,14 +420,14 @@ def prune_duplicate_nodes( def _is_duplicate( self, node: Node, - data: dict[str, Tensor | Scalar], + data: dict[str, Tensor[DataType] | Scalar], constant_keys: Mapping[str, DataType | MainValueType], ): if node.model is None: return # Model id is a unique key for unique operation - model_id = [] + model_id: list[str] = [] for key, conn in node.connections.items(): # We do not consider output and cache keys, when determining model id. if key == "output" or "cache" in key: @@ -445,7 +444,7 @@ def _is_duplicate( if not isinstance(value, ToBeDetermined): for value_key, ref_value in self.value_table.items(): if type(ref_value) is not type(value): - is_equal = False + is_equal: bool = False # Check tensors are equal elif self.is_tensor_type(ref_value) and self.is_tensor_type(value): is_equal = ( @@ -556,7 +555,7 @@ def remove_key(self, key: str): def infer_ignore_step( self, key: str, keys: set[str], queue: set[str], from_source: bool ): - forward_key_fn: Callable | partial + forward_key_fn: Callable[[str, bool], list[str]] if from_source: forward_key_fn = self.get_target_keys backward_key_fn = self.get_source_keys diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 16928078..c696aa27 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -28,7 +28,7 @@ Scalar, Table, Tensor, - Uniadic, + UniadicRecord, Updates, Variadic, _get_shapes, @@ -73,7 +73,7 @@ def __init__( *, discard_keys: StringOrConnectionSetType, data_keys: StringOrConnectionSetType, - constant_keys: PhysicalConstantType, + constant_keys: PhysicalConstantType[DataType], trainable_keys: StringOrConnectionSetType, jacobian_keys: StringOrConnectionSetType, shapes: PhysicalShapeType, @@ -133,8 +133,10 @@ def __init__( self.inference = inference # Initialize flat graph and data store. - self._flat_graph: FlatGraph = FlatGraph(self._input_keys, self._output_keys) - memo: dict[int, Tensor | Scalar] = {} + self._flat_graph: FlatGraph[DataType] = FlatGraph( + self._input_keys, self._output_keys + ) + memo: dict[int, Tensor[DataType] | Scalar] = {} self.data_store: StaticDataStore[DataType] = StaticDataStore( self._flat_graph, backend, inference, model.constraint_solver, memo ) @@ -171,7 +173,11 @@ def __init__( f"{', '.join(str(key) for key in unnamed_data_keys)}" ) - def __call__(self, params: dict | None = None, data: dict | None = None): + def __call__( + self, + params: dict[str, DataType] | None = None, + data: Mapping[str, DataType | MainValueType] | None = None, + ): return self.evaluate(params=params, data=data) def _convert_key(self, model: BaseModel, key: str | Connection) -> str: @@ -194,7 +200,7 @@ def _convert_key(self, model: BaseModel, key: str | Connection) -> str: def _check_overridden_nontrainable_keys( self, model: BaseModel, - constant_keys: PhysicalConstantType, + constant_keys: PhysicalConstantType[DataType], data_keys: StringOrConnectionSetType, ) -> None: for key in constant_keys.keys() | data_keys: @@ -267,10 +273,10 @@ def _validate_keys( def get_shapes( self, model: BaseModel | None = None, - uni_keys=None, - var_keys=None, - symbolic=False, - verbose=False, + uni_keys: dict[UniadicRecord, str] | None = None, + var_keys: dict[Variadic, str] | None = None, + symbolic: bool = False, + verbose: bool = False, ) -> _ShapesType: if model is not None: # Find corresponding data from self.data_store_data_memo. @@ -308,9 +314,9 @@ def flatten_dag( self, model: BaseModel, key_mappings: dict[str, str], - name="", - safe_shapes=True, - memo: dict | None = None, + name: str = "", + safe_shapes: bool = True, + memo: dict[int, Tensor[DataType] | Scalar] | None = None, ): _, reorder_graph = self._flatten_dag( model, key_mappings, name, safe_shapes, memo @@ -322,9 +328,9 @@ def _flatten_dag( self, model: BaseModel, key_mappings: dict[str, str], - name="", - safe_shapes=True, - memo: dict | None = None, + name: str = "", + safe_shapes: bool = True, + memo: dict[int, Tensor[DataType] | Scalar] | None = None, ) -> tuple[dict[str, str], bool]: if memo is None: memo = {} @@ -373,8 +379,8 @@ def _flatten_dag( if isinstance(model, PrimitiveModel): output = PrimitiveModel.output_key - _data_dict: dict[str, Tensor | Scalar] = {} - dag = {} + _data_dict: dict[str, Tensor[DataType] | Scalar] = {} + dag: dict[str, str] = {} for inner_key in model.external_keys: updated_inner_key = key_mappings.get(inner_key, inner_key) dag[inner_key] = updated_inner_key @@ -387,7 +393,9 @@ def _flatten_dag( if self.backend.type == "numpy": cache_name = "_".join([dag[output], model.cache_name]) dag["cache"] = cache_name - cache_value: dict | None = None if self.inference else dict() + cache_value: dict[str, MainValueType] | None = ( + None if self.inference else dict() + ) # Create a Scalar object for caches in manualgrad backend. cache_scalar = Scalar(dict | None, cache_value) self.data_store.update_data({cache_name: cache_scalar}) @@ -416,7 +424,7 @@ def _flatten_dag( m_name = name + "_" + m.__class__.__name__ + "_" + str(idx) source_name = m_name - m_mapping = dict() + m_mapping: dict[str, str] = dict() for key, value in model.dag[m].items(): if (res := key_mappings.get(value.key)) is not None: result = res @@ -534,7 +542,7 @@ def randomize_params( stacklevel=1, ) elif variadic: - shape = [item for item in shape if item != (...,)] + shape = [item for item in shape if item != (...,)] # type: ignore warnings.warn( f"Shape of {key} key automatically set to {shape} since it's " "shape includes variadic type!", @@ -564,7 +572,7 @@ def _pre_compile( ) self.jacobian_keys = jacobian_keys - self.ignore_grad_keys = set() + self.ignore_grad_keys: set[str] = set() for node in self._flat_graph.nodes.values(): conn_data = node.model.conns.get_connection("output") @@ -759,7 +767,7 @@ def infer_ignore( weak_keys: set[str], output_keys: set[str], strict_keys: set[str] | None = None, - update_graph=True, + update_graph: bool = True, ) -> tuple[set[str], set[str]]: """ Infers the keys which will be ignored @@ -823,12 +831,12 @@ def infer_ignore( def _calculate_parameters( self, name_mappings: dict[Model, str], - data_to_key_map: dict[Tensor | Scalar, list[str]] | None = None, + data_to_key_map: dict[Tensor[DataType] | Scalar, list[str]] | None = None, ): total_params: int = 0 - seen_data = set() - exact_param_status = True - param_info: dict[str, tuple[dict, dict]] = {} + seen_data: set[Tensor[DataType]] = set() + exact_param_status: bool = True + param_info: dict[str, tuple[dict[str, str], dict[str, str]]] = {} if data_to_key_map is None: data_to_key_map = {} @@ -894,7 +902,7 @@ def _calculate_parameters( def _print_model_info( self, total_params: str, - data_to_key_map: dict[Tensor | Scalar, list[str]], + data_to_key_map: dict[Tensor[DataType] | Scalar, list[str]], model: BaseModel | None = None, ): # Find constant inputs of the model. @@ -916,14 +924,7 @@ def _print_model_info( if model is not None: # Find all keys of the logical model, Then find the projection of those keys # in their corresponding physical model - projected_keys = set() - # for conn in model.conns.all.values(): - # if ( - # pm_keys := data_to_key_map.get( - # self.data_store.data_memo.get(id(conn.metadata.data)) - # ) - # ) is not None: - # projected_keys.update(pm_keys) + projected_keys: set[str] = set() for conn in model.conns.all.values(): if ( data := self.data_store.data_memo.get(id(conn.metadata.data)) @@ -972,7 +973,7 @@ def summary( print_info: bool = True, name: str | None = None, ): - uni_keys: dict[Uniadic, str] = dict() + uni_keys: dict[UniadicRecord, str] = dict() var_keys: dict[Variadic, str] = dict() if model is None and depth != 0: raise ValueError("Depth cannot be specified when model is not given") @@ -983,7 +984,7 @@ def summary( # If model is not None, create data to key map. this dict will point # determined key names in physical model. - data_to_key_map: dict[Tensor | Scalar, list[str]] = {} + data_to_key_map: dict[Tensor[DataType] | Scalar, list[str]] = {} for key, value in self.data.items(): data_to_key_map.setdefault(value, []).append(key) @@ -1130,8 +1131,8 @@ def extract_connection_info( return conn_info def _replace_with_primitive( - self, model: Model, key_mappings: dict - ) -> tuple[PrimitiveModel, dict]: + self, model: Model, key_mappings: dict[str, str] + ) -> tuple[PrimitiveModel, dict[str, str]]: assert model.formula_key is not None formula = self.backend.primitive_function_dict[model.formula_key] primitive_input_keys = formula.__code__.co_varnames[ diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 63fcd3bd..2998844b 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -27,9 +27,9 @@ class NestedListType: """ __slots__ = "base_type" - base_type: type + base_type: type | UnionType - def __init__(self, base_type): + def __init__(self, base_type: type | UnionType): self.base_type = base_type @@ -73,13 +73,7 @@ def list_shape(ndarray: list[float | int] | float | int) -> list[int]: return [] -def get_unique_types(arg: list | tuple): - # Recursively looks all items in nested sequence - # and returns all unique types in it. - ... - - -def align_shapes(all_dicts: list[dict]) -> None: +def align_shapes(all_dicts: list[dict[Any, Any]]) -> None: """Align all shapes given in the list Examples: @@ -159,9 +153,9 @@ class GeneratedFunction: serialization and deserialization methods. """ - def __init__(self, func: FunctionType, metadata: dict): + def __init__(self, func: FunctionType, metadata: dict[str, str]): self.func = func - self.metadata = metadata + self.metadata: dict[str, str] = metadata def __reduce__(self): # Serialize the function code and metadata @@ -169,11 +163,11 @@ def __reduce__(self): source_code = self.metadata["source"] return (self._unpickle, (source_code, fn_name)) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any): return self.func(*args, **kwargs) @staticmethod - def _unpickle(source_code, fn_name): + def _unpickle(source_code: str, fn_name: str): # Compile the code string back to a code object code = compile(source_code, "", "exec") namespace: dict[str, Any] = {} @@ -191,14 +185,14 @@ def infer_all_possible_types( for sub_type in sub_types: possible_types.update(infer_all_possible_types(sub_type)) elif isinstance(type_def, GenericAlias): - seq_type: type[tuple] | type[list] = type_def.__origin__ + seq_type: type[tuple[Any, ...]] | type[list[Any]] = type_def.__origin__ possible_seq_type: type sub_types = list(type_def.__args__) if seq_type is tuple: possible_seq_type = tuple if len(sub_types) == 2 and sub_types[-1] == ...: for typ in infer_all_possible_types(sub_types[0]): - _types = {possible_seq_type[typ, ...], possible_seq_type[typ]} + _types: Any = {possible_seq_type[typ, ...], possible_seq_type[typ]} possible_types.update(_types) else: type_probs = [infer_all_possible_types(typ) for typ in sub_types] @@ -216,19 +210,19 @@ def infer_all_possible_types( def find_list_base_type( - type_def: type[list] + type_def: type[list[Any]] | type[float] | type[int] | type[bool] | UnionType | NestedListType | GenericAlias, -) -> set[type]: - result = set() +) -> set[type | UnionType]: + result: set[type | UnionType] = set() if isinstance(type_def, NestedListType): result.add(type_def.base_type) elif isinstance(type_def, GenericAlias): - origin: type[list] | type[tuple] = type_def.__origin__ + origin: type[list[Any]] | type[tuple[Any, ...]] = type_def.__origin__ if origin is list: # Means there exists recursive list type. for arg in type_def.__args__: @@ -303,7 +297,7 @@ def find_intersection_type( # Constrain larger union type to the smaller one. base_type_1 = nested_types[0].base_type base_type_2 = nested_types[1].base_type - return NestedListType(find_intersection_type(base_type_1, base_type_2)) + return NestedListType(find_intersection_type(base_type_1, base_type_2)) # type: ignore # First find direct intersections. subtypes_1 = set(type_1.__args__) if type(type_1) is UnionType else {type_1} @@ -347,7 +341,7 @@ def find_intersection_type( # this means one of the types with origin are empty list or tuple, # in that case, take the empty one (tuple[()], or list[()]) as # intersection type - common = typ_1.__origin__[()] # type: ignore + common: Any = typ_1.__origin__[()] # type: ignore elif typ_1.__origin__ is tuple: ellipsis_1 = ... in args_1 @@ -425,7 +419,9 @@ def merge_dicts( return base_dict -def sort_type(type1: type | UnionType | GenericAlias): +def sort_type( + type1: type | UnionType | GenericAlias, +) -> type | UnionType | GenericAlias: """ Returns the sorted type of UnionTypes """ diff --git a/mithril/models/models.py b/mithril/models/models.py index be7ec1f5..a974f985 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -424,7 +424,7 @@ def __init__( | tuple[tuple[int, int], tuple[int, int]] | ToBeDetermined = (0, 0), dilation: int | tuple[int, int] | ToBeDetermined = (1, 1), - use_bias=True, + use_bias: bool = True, ) -> None: super().__init__() @@ -502,7 +502,7 @@ class Linear(Model): w: Connection b: Connection - def __init__(self, dimension: int | None = None, use_bias=True) -> None: + def __init__(self, dimension: int | None = None, use_bias: bool = True) -> None: super().__init__() self.factory_args = {"dimension": dimension, "use_bias": use_bias} dim: int | str = "d_out" if dimension is None else dimension @@ -613,7 +613,9 @@ class LayerNorm(Model): w: Connection b: Connection - def __init__(self, use_scale=True, use_bias=True, eps=1e-5) -> None: + def __init__( + self, use_scale: bool = True, use_bias: bool = True, eps: float = 1e-5 + ) -> None: super().__init__() self.factory_args = {"use_scale": use_scale, "use_bias": use_bias, "eps": eps} @@ -681,7 +683,11 @@ class GroupNorm(Model): output: Connection def __init__( - self, num_groups: int = 32, use_scale=True, use_bias=True, eps=1e-5 + self, + num_groups: int = 32, + use_scale: bool = True, + use_bias: bool = True, + eps: float = 1e-5, ) -> None: super().__init__() @@ -1710,7 +1716,7 @@ def __init__( cell_type: Cell, max_input_sequence_length: int, max_target_sequence_length: int, - teacher_forcing=False, + teacher_forcing: bool = False, ) -> None: super().__init__() @@ -2177,7 +2183,7 @@ class MDS(DistanceEncoder): predicted_coords: Connection output: Connection - def __init__(self, prediction_dim: int, input_type="distances"): + def __init__(self, prediction_dim: int, input_type: str = "distances"): assert input_type in ["distances", "powered_distances", "points"] base_model = MDSCore(exact_distances=(input_type == "distances")) super().__init__(base_model=base_model, input_type=input_type) @@ -2218,8 +2224,8 @@ class TSNE(DistanceEncoder): def __init__( self, prediction_dim: int, - input_type="distances", - preplexity=20.0, + input_type: str = "distances", + preplexity: float = 20.0, calculate_p_joint: bool = False, ): assert input_type in ["distances", "powered_distances", "points"] @@ -2368,7 +2374,7 @@ class GPRLoss(Model): alpha: Connection output: Connection - def __init__(self, robust=False) -> None: + def __init__(self, robust: bool = False) -> None: super().__init__() self.factory_args = {"robust": robust} diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index afce21f6..ba84fdf1 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -119,7 +119,7 @@ class CustomPrimitiveModel(PrimitiveModel): - def __init__(self, formula_key: str, **kwargs) -> None: + def __init__(self, formula_key: str, **kwargs: TensorType | Scalar) -> None: self.factory_args = {"formula_key": formula_key} | kwargs super().__init__(formula_key, **kwargs) @@ -140,9 +140,12 @@ class SupervisedLoss(PrimitiveModel): output: Connection def __init__( - self, formula_key: str, polymorphic_constraint: bool = True, **kwargs + self, + formula_key: str, + polymorphic_constraint: bool = True, + **kwargs: TensorType | Scalar, ) -> None: - default_kwargs = { + default_kwargs: dict[str, TensorType | Scalar] = { "output": TensorType([("Var_1", ...)]), "input": TensorType([("Var_2", ...)]), "target": TensorType([("Var_3", ...)]), @@ -338,7 +341,7 @@ def __call__( # type: ignore[override] } # Check if the given argument set is valid. if self.formula_key == "cross_entropy_with_log_probs": - args = [] + args: list[str] = [] if robust != NOT_GIVEN: args.append("robust") if cutoff != NOT_GIVEN: @@ -601,13 +604,16 @@ class Activation(PrimitiveModel): output: Connection def __init__( - self, formula_key, polymorphic_constraint: bool = False, **kwargs + self, + formula_key: str, + polymorphic_constraint: bool = False, + **kwargs: TensorType | Scalar, ) -> None: # NOTE: Torch and JAX behave different for some activation functions. # For example JAX handles int type inputs for GELU or LeakyRelu while # Torch assumes only float inputs for these activations. Since JAX handles # more general case, default types are written taking this into account. - default_kwargs = dict( + default_kwargs: dict[str, TensorType | Scalar] = dict( input=TensorType([("Var", ...)]), output=TensorType([("Var", ...)], float) ) # Finalize kwargs. @@ -876,7 +882,7 @@ class PrimitiveConvolution2D(PrimitiveModel): output: Connection bias: Connection - def __init__(self, use_bias=True) -> None: + def __init__(self, use_bias: bool = True) -> None: self.factory_args = {"use_bias": use_bias} formula_key = "conv2d_bias" kwargs: dict[str, TensorType | Scalar] = { diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index d8de72db..2860b527 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -23,9 +23,12 @@ ConnectionData, ConnectionType, ExtendInfo, + IOHyperEdge, IOKey, KeyType, Model, + UniadicRecord, + Variadic, _get_shapes, get_summary_shapes, ) @@ -120,7 +123,7 @@ def get_single_output(model: BaseModel) -> Connection: return getattr(model, out_key) @staticmethod - def check_finalized(fn: Callable): + def check_finalized(fn: Callable[..., Any]) -> Callable[..., Any]: """Decorator to check if given TrainModel is finalized or not. Parameters @@ -129,7 +132,7 @@ def check_finalized(fn: Callable): Any of TrainModel modification methods. """ - def check_fn(context: "TrainModel", *args, **kwargs): + def check_fn(context: "TrainModel", *args: Any, **kwargs: Any): if context._is_finalized: raise Exception( "No modifications can be made to a finalized TrainModel!" @@ -145,14 +148,14 @@ def add_loss( reduce_steps: list[BaseModel] | None = None, key_name: str | None = None, coef: float | None = None, - **kwargs, + **kwargs: Any, ) -> None: # If provided key namings does not match with Loss model keys = set(loss_model._input_keys) - loss_model.conns.get_non_diff_keys() if set(kwargs.keys()) != keys: raise KeyError("The provided keys do not match the model's loss.") - outputs_conns_metadata = set() + outputs_conns_metadata: set[IOHyperEdge] = set() if len(self.conns.output_keys) > 0: for key in self.conns.output_keys: if (given_conn := self.conns.get_connection(key)) is None: @@ -174,7 +177,7 @@ def add_loss( is_loss_connected = True if isinstance(value, Connection): conn = value.data - elif isinstance(value, str): + else: if value not in self.conns.all: raise KeyError("Key does not belong to the Model!") else: @@ -210,7 +213,7 @@ def add_loss( # TODO: Currently kwargs contains only input keys of # first (loss) model. # We may want to add output key for the final model's output key. - reduce_inputs = [] + reduce_inputs: list[tuple[Connection, Connection]] = [] for key in kwargs: if key in loss_model.conns.output_keys: raise KeyError("Output of the loss model cannot be defined!") @@ -266,7 +269,7 @@ def add_regularization( coef: float, reg_key: str | Connection | None = None, key_name: str | None = None, - **kwargs, + **kwargs: Any, ): keys = set(model._input_keys) - model.conns.get_non_diff_keys() if set(kwargs.keys()) != keys: @@ -303,7 +306,7 @@ def _add_regularization( coef: float, reg_key: str | Connection | None = None, key_name: str | None = None, - **kwargs, + **kwargs: Any, ) -> None: # TODO: check if reg_key is single # TODO: Maybe use canonical input to decide reg_key!!! @@ -337,14 +340,14 @@ def _add_regularization( } input_keys = {key for key in self._input_keys if "$" not in key} trainable_keys = (input_keys | set(generated_keys.values())) - non_diff_keys - trainables = set() + trainables: set[IOHyperEdge] = set() for key in trainable_keys: if key in self.conns.all: if (t_key := self.conns.get_connection(key)) is None: raise KeyError("Given key does not belong to the Model!") trainables.add(t_key.metadata) - provided_outputs = set() + provided_outputs: set[IOHyperEdge] = set() for value in kwargs.values(): if isinstance(value, ConnectionData): provided_outputs.add(value.metadata) @@ -386,7 +389,7 @@ def add_metric( model: Model, reduce_steps: list[BaseModel] | None = None, key_name: str | None = None, - **kwargs, + **kwargs: Any, ) -> None: # TODO: Somehow we need to imply metric is attached and self model # could not be extended or be used as another model's child model. @@ -428,7 +431,8 @@ def _add_loss_combiner(self): loss_output_key = LossKey if self.reg_coef_map else FinalCost if (num_of_loss_keys := len(self.loss_keys)) > 1: concat_model = Concat(n=num_of_loss_keys, axis=None) - concat_kwargs, idx = {}, 0 + concat_kwargs: dict[Any, Any] = {} + idx = 0 for key in concat_model._input_keys: # if not concat_model.connections[key].metadata.value.is_non_diff: if not concat_model.conns.is_key_non_diff(key): @@ -502,9 +506,9 @@ def summary( types: bool = False, symbolic: bool = False, name: str | None = None, - alternative_shapes=False, - uni_cache: dict | None = None, - var_cache: dict | None = None, + alternative_shapes: bool = False, + uni_cache: dict[UniadicRecord, str] | None = None, + var_cache: dict[Variadic, str] | None = None, depth: int = 0, ): # TODO: Use all the arguments given above: @@ -524,7 +528,7 @@ def summary( if isinstance(self._model, Model): summary_kwargs["depth"] = depth - self._model.summary(**summary_kwargs) + self._model.summary(**summary_kwargs) # type: ignore name_mappings = define_unique_names(self.dag) conn_info = self.extract_connection_info(name_mappings) @@ -546,7 +550,7 @@ def summary( ), ) - shape_info = get_summary_shapes(model_shapes, conn_info) + shape_info = get_summary_shapes(model_shapes, conn_info) # type: ignore if self.loss_keys: # If any loss is attached, extract useful information # about each added loss and print the table @@ -633,7 +637,7 @@ def summary( ["Metric Model", "Keys", "Shapes", "Connections", "Output key"] ) for m_key in self.metric_keys: - m_list = [] + m_list: list[list[str]] = [] m_conn = self.conns.get_connection(m_key) assert m_conn is not None model = self.dependency_map._local_output_dependency_map[m_conn][0] @@ -671,7 +675,7 @@ def _add_geo_mean(self): geo_mappings[reg_info].append(self.reduce_inputs[key]) for reg_info, loss_connections in geo_mappings.items(): - final_outputs = [] + final_outputs: list[Connection | int] = [] for reduce in loss_connections: final_outputs.append(self._add_reduce_sizes(reduce)) if final_outputs: @@ -707,9 +711,9 @@ def _add_geo_mean(self): assert out_con is not None self.reg_coef_map[coef].add(out_con.conn) - def _add_reduce_sizes(self, reduce_list): + def _add_reduce_sizes(self, reduce_list: list[tuple[Connection, Connection]]): final_output: Connection | int = 1 - sizes = [] + sizes: list[Connection] = [] for input, dim in reduce_list: m = _create_size() self.extend(m, input=input, dim=dim) diff --git a/mithril/utils/func_utils.py b/mithril/utils/func_utils.py index 24f0d48b..7a867222 100644 --- a/mithril/utils/func_utils.py +++ b/mithril/utils/func_utils.py @@ -14,6 +14,7 @@ from collections.abc import Callable from copy import deepcopy +from typing import Any from ..framework.common import Scalar, Tensor @@ -21,7 +22,7 @@ def prepare_function_args( - data: dict[str, Tensor | Scalar], + data: dict[str, Tensor[Any] | Scalar], function: Callable, inputs: key_map_type, array_creation_funcs: list[str], @@ -73,7 +74,7 @@ def prepare_function_args( def create_kwarg_dict( - data: dict[str, Tensor | Scalar], + data: dict[str, Tensor[Any] | Scalar], kwarg_keys: list[str], function: Callable, inputs: key_map_type, @@ -147,7 +148,7 @@ def reorganize_args( return organized_arguments -def is_make_array_required(data: Tensor | Scalar): +def is_make_array_required(data: Tensor[Any] | Scalar): if isinstance(data, Tensor): _temp_shape = next(iter(data.shape.reprs)) # It is needed to guarantee that Tensor is at least one dimensional. diff --git a/mithril/utils/type_utils.py b/mithril/utils/type_utils.py index 19a68e20..a1c975e1 100644 --- a/mithril/utils/type_utils.py +++ b/mithril/utils/type_utils.py @@ -14,7 +14,7 @@ from __future__ import annotations from types import EllipsisType -from typing import TypeGuard +from typing import Any, TypeGuard def is_int_tuple_tuple( @@ -23,28 +23,28 @@ def is_int_tuple_tuple( return isinstance(data[0], tuple) -def is_tuple_int(t) -> TypeGuard[tuple[int, ...]]: +def is_tuple_int(t: Any) -> TypeGuard[tuple[int, ...]]: return isinstance(t, tuple) and all(isinstance(i, int) for i in t) -def is_list_int(t) -> TypeGuard[list[int]]: +def is_list_int(t: Any) -> TypeGuard[list[int]]: return isinstance(t, list) and all(isinstance(i, int) for i in t) -def is_list_str(t) -> TypeGuard[list[str]]: +def is_list_str(t: Any) -> TypeGuard[list[str]]: return isinstance(t, list) and all(isinstance(i, str) for i in t) -def is_list_int_or_none(t) -> TypeGuard[list[int | None]]: +def is_list_int_or_none(t: Any) -> TypeGuard[list[int | None]]: return isinstance(t, list) and all(isinstance(i, int | None) for i in t) -def is_tuple_int_or_none(t) -> TypeGuard[tuple[int | None, ...]]: +def is_tuple_int_or_none(t: Any) -> TypeGuard[tuple[int | None, ...]]: return isinstance(t, tuple) and all(isinstance(i, int | None) for i in t) def is_axis_reduce_type( - axis, + axis: Any, ) -> TypeGuard[int | tuple[int, ...] | None]: is_int = isinstance(axis, int) is_int_tuple = is_tuple_int(axis) @@ -53,7 +53,7 @@ def is_axis_reduce_type( def is_axis_reverse_type( - axis, + axis: Any, ) -> TypeGuard[list[int] | tuple[int, ...] | None]: is_list = is_list_int(axis) is_tuple = is_tuple_int(axis) @@ -61,7 +61,7 @@ def is_axis_reverse_type( return is_list or is_none or is_tuple -def is_tuple_of_two_ints(obj) -> TypeGuard[tuple[int, int]]: +def is_tuple_of_two_ints(obj: Any) -> TypeGuard[tuple[int, int]]: return ( isinstance(obj, tuple) and len(obj) == 2 @@ -70,7 +70,7 @@ def is_tuple_of_two_ints(obj) -> TypeGuard[tuple[int, int]]: def is_padding_type( - padding, + padding: Any, ) -> TypeGuard[tuple[tuple[int, int], tuple[int, int]] | tuple[int, int]]: is_padding = False if isinstance(padding, tuple) and len(padding) == 2: @@ -83,7 +83,7 @@ def is_padding_type( def is_index_type( - index, + index: Any, ) -> TypeGuard[tuple[int | slice | EllipsisType | None, ...]]: return isinstance(index, tuple) and all( isinstance(i, int | slice | EllipsisType | None) for i in index diff --git a/tests/scripts/test_compile_keys_consistencies.py b/tests/scripts/test_compile_keys_consistencies.py index 9274f558..102f73cb 100644 --- a/tests/scripts/test_compile_keys_consistencies.py +++ b/tests/scripts/test_compile_keys_consistencies.py @@ -26,7 +26,7 @@ def test_dollar_sign_str(): Tries all possible compile keys. """ model = Model() - model += Linear(1, 1) + model += Linear(1, True) backend = TorchBackend() kwargs: dict @@ -63,7 +63,7 @@ def test_connection_not_found(): Tries all possible compile keys. """ model = Model() - lin_model = Linear(1, 1) + lin_model = Linear(1, True) mult_model = Multiply() model += lin_model @@ -96,7 +96,7 @@ def test_string_not_found(): Tries all possible compile keys. """ model = Model() - lin_model = Linear(1, 1) + lin_model = Linear(1, True) model += lin_model backend = TorchBackend() @@ -130,7 +130,7 @@ def test_reset_static_data(): Tests for constant_keys and data_keys. """ model = Model() - model += Linear(1, 1)(input=IOKey(name="input", value=[[2.0]])) + model += Linear(1, True)(input=IOKey(name="input", value=[[2.0]])) backend = TorchBackend() kwargs: dict @@ -152,7 +152,7 @@ def test_reset_static_data_2(): Tests for constant_keys and data_keys for connection type keys. """ model = Model() - model += Linear(1, 1)(input=IOKey(name="input", value=[[2.0]])) + model += Linear(1, True)(input=IOKey(name="input", value=[[2.0]])) backend = TorchBackend() kwargs: dict @@ -175,7 +175,7 @@ def test_check_keys_disjoint_sets(): must be disjoint sets. """ model = Model() - model += (lin_model := Linear(1, 1))("input") + model += (lin_model := Linear(1, True))("input") backend = TorchBackend() with pytest.raises(ValueError) as err_info: @@ -207,7 +207,7 @@ def test_static_keys_inputs_only(): other than the inputs of the model. """ model = Model() - model += (lin_model := Linear(1, 1))(input="input", output="lin_out") + model += (lin_model := Linear(1, True))(input="input", output="lin_out") model += Multiply()(output=IOKey(name="output")) backend = TorchBackend() @@ -225,7 +225,7 @@ def test_trainable_keys_inputs_only(): other than the inputs of the model. """ model = Model() - model += (lin_model := Linear(1, 1))(input="input", output="lin_out") + model += (lin_model := Linear(1, True))(input="input", output="lin_out") model += Multiply()(output=IOKey(name="output")) backend = TorchBackend() @@ -243,7 +243,7 @@ def test_discard_keys_input_and_outputs_only(): other than the inputs and outputs of the model. """ model = Model() - model += (lin_model := Linear(1, 1))(input="input", output="lin_out") + model += (lin_model := Linear(1, True))(input="input", output="lin_out") model += Multiply()(output=IOKey(name="output")) backend = TorchBackend() @@ -261,7 +261,7 @@ def test_jacobian_keys_inputs_only(): other than the inputs of the model. """ model = Model() - model += (lin_model := Linear(1, 1))(input="input", output="lin_out") + model += (lin_model := Linear(1, True))(input="input", output="lin_out") model += Multiply()(output=IOKey(name="output")) backend = TorchBackend() @@ -281,7 +281,7 @@ def test_iterable_type_keys(): only non_trainable key to a trainable key. """ model = Model() - model += Linear(1, 1)("input") + model += Linear(1, True)("input") backend = TorchBackend() for typ in [list, tuple, set, dict]: diff --git a/tests/scripts/test_type_consistencies.py b/tests/scripts/test_type_consistencies.py index 0306c643..9f411d57 100644 --- a/tests/scripts/test_type_consistencies.py +++ b/tests/scripts/test_type_consistencies.py @@ -993,7 +993,7 @@ def test_find_dominant_type_16(): def test_sort_type_1(): input = int new_type = sort_type(input) - assert new_type.__name__ == "int" + assert new_type.__name__ == "int" # type: ignore def test_sort_type_2(): From f43801bcefc5e1501488084475c095b3f40fd2c6 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 28 Nov 2024 13:55:28 +0300 Subject: [PATCH 07/26] fix: reviews are applied --- mithril/backends/backend.py | 2 +- mithril/framework/codegen/c_gen.py | 3 +-- mithril/framework/common.py | 16 ++++++++-------- mithril/framework/logical/model.py | 2 +- mithril/framework/logical/primitive.py | 10 ---------- mithril/framework/utils.py | 5 ++++- 6 files changed, 15 insertions(+), 23 deletions(-) diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py index 6d75399d..5ca12be5 100644 --- a/mithril/backends/backend.py +++ b/mithril/backends/backend.py @@ -38,7 +38,7 @@ class Backend(ABC, Generic[DataType]): is_installed = True _device: str _precision: int - primitive_function_dict: dict[str, Callable[..., DataType]] + primitive_function_dict: dict[str, Callable[..., DataType | Any]] registered_primitives: dict[str, Callable[..., DataType]] array_creation_funcs: list[str] primitive_fn_path: str diff --git a/mithril/framework/codegen/c_gen.py b/mithril/framework/codegen/c_gen.py index 643ebeb8..d935e05f 100644 --- a/mithril/framework/codegen/c_gen.py +++ b/mithril/framework/codegen/c_gen.py @@ -17,7 +17,6 @@ import subprocess import tempfile from functools import partial -from typing import Any from ...backends.with_manualgrad.c_backend import CBackend, backend from ...backends.with_manualgrad.c_backend.src import array @@ -34,7 +33,7 @@ class CGen(CodeGen): BACKWARD_FN_SUFFIX = "_grad" - def __init__(self, pm: PhysicalModel[Any]) -> None: + def __init__(self, pm: PhysicalModel[PyArray]) -> None: super().__init__(pm) assert isinstance(self.pm.backend, CBackend) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index b189e51a..817c7825 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -2700,14 +2700,14 @@ def __hash__(self) -> int: @overload -def add_lens(x: int, y: int, const: Iterator[int]) -> int: ... +def add_lengths(x: int, y: int, const: Iterator[int]) -> int: ... @overload -def add_lens(x: str, y: str, const: Iterator[str]) -> str: ... +def add_lengths(x: str, y: str, const: Iterator[str]) -> str: ... -def add_lens(x: str | int, y: str | int, const: Iterator[str | int]) -> str | int: +def add_lengths(x: str | int, y: str | int, const: Iterator[str | int]) -> str | int: return x + next(const) + y # type: ignore @@ -2733,11 +2733,10 @@ def add_row(self, row: RowColumnType): self.cells.append(row) def add_column(self, column: RowColumnType): - idx = 0 for idx, row in enumerate(column[: len(self.headers)]): self.headers[idx].append(row) # type: ignore - for idx_, row in enumerate(column[idx + 1 :]): + for idx_, row in enumerate(column[len(self.headers) :]): self.cells[idx_].append(row) # type: ignore def _calculate_table_specs(self): @@ -2849,13 +2848,14 @@ def compile( self._adjust_table() # calculate total table width table_width = reduce( # type: ignore - partial(add_lens, const=(len(row) for row in row_sep)), self.each_row_width + partial(add_lengths, const=(len(row) for row in row_sep)), + self.each_row_width, ) table_constructor_fn: Callable[[str, str], str] = partial( - add_lens, const=cycle(row_sep) + add_lengths, const=cycle(row_sep) ) # type: ignore table_constructor_fn_w_spaces: Callable[[str, str], str] = partial( # type: ignore - add_lens, const=cycle(len(row) * " " for row in row_sep) + add_lengths, const=cycle(len(row) * " " for row in row_sep) ) end = "\n" header_list: list[str] = [] diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index ce007c26..992ae49d 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -1101,7 +1101,7 @@ def extend( # Hold shape information for IOKey type values in order # to set all in a bulk after all connections are added. if value._shape is not None: - shape_info |= {key: value._shape} # + shape_info |= {key: value._shape} if value._type is not None: type_info[key] = value._type diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 46e3e17e..c3b2f851 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -134,16 +134,6 @@ def __iadd__(self, other: BaseModel): f"Primitive '{self.__class__.__name__}' model can not be extended!" ) - @staticmethod - def convert_to_tuple(value: int | tuple[int, int] | list[Any]) -> tuple[int, int]: - if isinstance(value, int): - new_value = (value, value) - elif isinstance(value, list): - new_value = tuple(value) - else: - new_value = value - return new_value - def extract_connection_info( self, name_mappings: dict[BaseModel, str], diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 21fe8208..3994b0b1 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -194,7 +194,10 @@ def infer_all_possible_types( possible_seq_type = tuple if len(sub_types) == 2 and sub_types[-1] == ...: for typ in infer_all_possible_types(sub_types[0]): - _types: Any = {possible_seq_type[typ, ...], possible_seq_type[typ]} + _types: set[Any] = { + possible_seq_type[typ, ...], + possible_seq_type[typ], + } possible_types.update(_types) else: type_probs = [infer_all_possible_types(typ) for typ in sub_types] From d40da6b0c6ccb73aa9de93778d7a313cc1e38022 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 28 Nov 2024 14:46:55 +0300 Subject: [PATCH 08/26] apply reviews --- mithril/framework/codegen/__init__.py | 2 +- mithril/framework/codegen/code_gen.py | 10 ++++++---- mithril/framework/common.py | 24 +++++++++++------------- mithril/framework/physical/model.py | 1 + 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/mithril/framework/codegen/__init__.py b/mithril/framework/codegen/__init__.py index 9f7c3fbf..352a5693 100644 --- a/mithril/framework/codegen/__init__.py +++ b/mithril/framework/codegen/__init__.py @@ -19,7 +19,7 @@ from .code_gen import CodeGen from .python_gen import PythonCodeGen -code_gen_map: dict[type[Backend[Any]], type[CodeGen]] = {} +code_gen_map: dict[type[Backend[Any]], type[CodeGen[Any]]] = {} try: from ...backends.with_autograd.jax_backend import JaxBackend diff --git a/mithril/framework/codegen/code_gen.py b/mithril/framework/codegen/code_gen.py index eccd03e7..90fbf686 100644 --- a/mithril/framework/codegen/code_gen.py +++ b/mithril/framework/codegen/code_gen.py @@ -14,14 +14,16 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Any +from typing import Generic + +from mithril import DataType from ..physical.model import PhysicalModel -class CodeGen(ABC): - def __init__(self, pm: PhysicalModel[Any]) -> None: - self.pm = pm +class CodeGen(ABC, Generic[DataType]): + def __init__(self, pm: PhysicalModel[DataType]) -> None: + self.pm: PhysicalModel[DataType] = pm self.code: str | None = None self.file_path: str | None = None diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 817c7825..dacc4bd6 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -526,7 +526,7 @@ def _get_shapes( symbolic: bool = True, verbose: bool = False, key_mappings: dict[str, str] | None = None, -) -> dict[str, ShapeTemplateType | list[ShapeTemplateType]]: +) -> dict[str, ShapeTemplateType | list[ShapeTemplateType] | None]: if key_mappings is None: key_mappings = {} if uniadic_keys is None: @@ -542,7 +542,7 @@ def _get_shapes( ) else: shapes[key_name] = None - return shapes # type: ignore + return shapes class BaseData: @@ -583,7 +583,7 @@ def finalize_match(self, other: Tensor[Any] | Scalar): Val_Type = TypeVar("Val_Type", MainValueType, DataType) # type: ignore - def set_value(self, value: Val_Type): + def set_value(self, value: Val_Type) -> Updates: raise NotImplementedError("No 'set_value' method implemented.") def set_type(self, type: type[TensorValueType] | ScalarType | UnionType) -> Updates: @@ -633,7 +633,8 @@ def remove_constraint(self, constraint: Constraint): def match_shapes(self, other: BaseData): return Updates() - def match(self, other: Tensor[Any] | Scalar) -> Updates: + def match(self, other: Tensor[DataType] | Scalar) -> Updates: + self._differentiable: bool updates = Updates() if self != other: updates = Updates() @@ -695,7 +696,7 @@ def _set_as_physical(self): super()._set_as_physical() def make_physical( - self, backend: Backend[Any], memo: dict[int, Tensor[Any] | Scalar] + self, backend: Backend[DataType], memo: dict[int, Tensor[DataType] | Scalar] ): physical_tensor = deepcopy(self, memo) # Update data as physical data. @@ -703,7 +704,7 @@ def make_physical( # Update value of physical data taking backend into account. return physical_tensor - def __deepcopy__(self, memo: dict[int, Tensor[Any] | Scalar]): + def __deepcopy__(self, memo: dict[int, Tensor[DataType] | Scalar]): # Check if the object is already in the memo dictionary. if id(self) in memo: return memo[id(self)] @@ -719,7 +720,7 @@ def __deepcopy__(self, memo: dict[int, Tensor[Any] | Scalar]): setattr(new_instance, k, deepcopy(v, memo)) return new_instance - def match_shapes(self, other: Tensor[Any]): # type: ignore[override] + def match_shapes(self, other: Tensor[DataType]): # type: ignore[override] updates = Updates() if other.shape != self.shape: updates |= self.shape.merge(other.shape) @@ -805,12 +806,12 @@ def find_type(self, value: MainValueType | str) -> ScalarType: else: return find_type(value) - def _convert_value(self, backend: Backend[Any]): + def _convert_value(self, backend: Backend[DataType]): self.value = backend.cast(self.value) return self.value def make_physical( - self, backend: Backend[Any], memo: dict[int, Tensor[Any] | Scalar] + self, backend: Backend[DataType], memo: dict[int, Tensor[DataType] | Scalar] ): new_scalar = deepcopy(self, memo) if id(self) not in memo: @@ -1182,7 +1183,7 @@ def __hash__(self) -> int: | Mapping[str, ShapeTemplateType] | Mapping[Connection, ShapeTemplateType] ) -_ShapesType = Mapping[str, ShapeTemplateType | list[ShapeTemplateType]] +_ShapesType = Mapping[str, ShapeTemplateType | list[ShapeTemplateType] | None] @dataclass @@ -1203,9 +1204,6 @@ class ConnectionData: def __hash__(self) -> int: return hash(id(self)) - def __and__(self, other) -> Connect: - return Connect(self.conn) & other - def __eq__(self, other: object) -> bool: return id(self) == id(other) diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index e6a65c53..9c6ca319 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -545,6 +545,7 @@ def randomize_params( # seed_key = self.backend.set_seed_key(seed, seed_key) shape = self.shapes[key] + assert shape is not None shape_len = len(shape) if None in shape: raise Exception( From 4f8345989fad8203a7e84d7f50e256261379172c Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 2 Dec 2024 13:30:53 +0300 Subject: [PATCH 09/26] evaluate return type is updated --- examples/gpt/run_sample.py | 2 +- .../model_api/cnn_forcast_sine_training.py | 10 +- .../variable_length_many_to_one_lstm.py | 4 +- .../with_manualgrad/numpy_backend/backend.py | 4 +- mithril/framework/codegen/c_gen.py | 23 +-- mithril/framework/codegen/code_gen.py | 8 +- mithril/framework/codegen/numpy_gen.py | 33 +-- mithril/framework/codegen/python_gen.py | 6 +- mithril/framework/codegen/utils.py | 3 +- mithril/framework/constraints.py | 2 +- mithril/framework/logical/base.py | 5 +- mithril/framework/logical/model.py | 10 +- mithril/framework/physical/model.py | 84 ++++---- mithril/framework/utils.py | 6 +- mithril/utils/func_utils.py | 2 +- mithril/utils/utils.py | 2 +- .../randomized_model_tests_all_backends.json | 6 +- tests/scripts/helper.py | 6 +- tests/scripts/test_all_models.py | 60 +++--- tests/scripts/test_c_backend.py | 35 +++- tests/scripts/test_constant_inputs.py | 61 ++++-- tests/scripts/test_extend_template.py | 194 +++++++++++------- tests/scripts/test_io_key.py | 42 ++-- tests/scripts/test_jittable.py | 10 +- tests/scripts/test_metrics.py | 58 ++++-- tests/scripts/test_parallel.py | 87 +++++--- tests/scripts/test_recurrent_models.py | 3 +- tests/scripts/test_scripts.py | 118 ++++++++--- tests/scripts/test_train_context.py | 32 +-- tests/scripts/test_type_coercion.py | 35 +++- tests/scripts/test_type_consistencies.py | 11 +- tests/utils.py | 2 +- 32 files changed, 626 insertions(+), 338 deletions(-) diff --git a/examples/gpt/run_sample.py b/examples/gpt/run_sample.py index fc6e416a..47397105 100644 --- a/examples/gpt/run_sample.py +++ b/examples/gpt/run_sample.py @@ -152,7 +152,7 @@ def generate( outputs = model.evaluate(weights, data={"input": idx_cond}) logits = outputs["output"] # Pluck the logits at the final step and scale by desired temperature - logits = logits[:, -1, :] / temperature + logits = logits[:, -1, :] / temperature # type: ignore # Optionally crop the logits to only the top k options if top_k is not None: v = model.backend.topk(logits, min(top_k, logits.shape[-1])) diff --git a/examples/model_api/cnn_forcast_sine_training.py b/examples/model_api/cnn_forcast_sine_training.py index d8c86b7e..4c18717b 100644 --- a/examples/model_api/cnn_forcast_sine_training.py +++ b/examples/model_api/cnn_forcast_sine_training.py @@ -111,8 +111,7 @@ def generate_sine_wave(seq_len, num_samples): data = {"input": inputs, "target": targets} outputs, gradients = pm.evaluate_all(params, data) params, opt_state = optimizer.update_params(params, gradients, opt_state) - - total_loss += outputs["final_cost"] + total_loss += outputs["final_cost"] # type: ignore print(f"Epoch: {epoch} / {num_epochs} -> ", total_loss / len(dataloader)) # Test with single sample. @@ -130,7 +129,12 @@ def generate_sine_wave(seq_len, num_samples): linestyle="", ) plt.plot( - seq_len, pred.reshape(1), label="Predicted", color="red", marker="o", linestyle="" + seq_len, + pred.reshape(1), # type: ignore + label="Predicted", + color="red", + marker="o", + linestyle="", ) plt.legend() plt.show() diff --git a/examples/model_api/variable_length_many_to_one_lstm.py b/examples/model_api/variable_length_many_to_one_lstm.py index 5a6846d3..1671ed61 100644 --- a/examples/model_api/variable_length_many_to_one_lstm.py +++ b/examples/model_api/variable_length_many_to_one_lstm.py @@ -196,7 +196,7 @@ # Unpack time data to single tensors for output and target data. unpacked_output_data = unpack_time_slot_data( backend=backend, - data=outputs, + data=outputs, # type: ignore max_length=inference_max_target_length, max_size=len(test_data), output_dim=output_dim, @@ -204,4 +204,4 @@ ) # Measure test error. -error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum() +error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum() # type: ignore diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index ac337735..e3d705a3 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -184,8 +184,8 @@ def ones_like(self, input: np.ndarray, *, dtype: Dtype | None = None) -> np.ndar ) def zeros_like( - self, input: np.ndarray, *, dtype: Dtype | None = None - ) -> np.ndarray: + self, input: np.ndarray[Any, Any], *, dtype: Dtype | None = None + ) -> np.ndarray[Any, Any]: _dtype: str | None = None if isinstance(dtype, Dtype): _dtype = dtype.name diff --git a/mithril/framework/codegen/c_gen.py b/mithril/framework/codegen/c_gen.py index d935e05f..e595711c 100644 --- a/mithril/framework/codegen/c_gen.py +++ b/mithril/framework/codegen/c_gen.py @@ -80,8 +80,6 @@ def compile_code(self, jit: bool = False, compile_flags: list[str] | None = None assert self.file_path is not None, "Code has not been generated yet!" eval_arg_keys = self.func_arg_keys["evaluate"] - if not self.pm.inference: - eval_grad_arg_keys = self.func_arg_keys["evaluate_gradients"] so_file_path = self.file_path.replace(".c", ".so") default_compile_flags = ["cc", self.file_path, "-shared", "-fPIC"] @@ -106,6 +104,7 @@ def compile_code(self, jit: bool = False, compile_flags: list[str] | None = None lib = ctypes.CDLL(so_file_path) lib.evaluate.argtypes = [ctypes.POINTER(Array)] * len(eval_arg_keys) if not self.pm.inference: + eval_grad_arg_keys = self.func_arg_keys["evaluate_gradients"] lib.evaluate_gradients.argtypes = [ctypes.POINTER(Array)] * len( eval_grad_arg_keys ) @@ -113,8 +112,8 @@ def compile_code(self, jit: bool = False, compile_flags: list[str] | None = None # we need backend data types! # include_internals flag is used for get internal values for backpropagation def evaluate_wrapper( - params: dict[str, PyArray], - data: dict[str, PyArray], + params: dict[str, PyArray] | None, + data: dict[str, PyArray] | None, cache: dict[str, PyArray] | None = None, include_internals: bool = False, ) -> dict[str, PyArray]: @@ -194,8 +193,8 @@ def create_primitive_call(self, formula_name: str, args: list[str]) -> c_ast.Exp return c_ast.Call(formula_name, args) def generate_evaluate(self) -> tuple[c_ast.FunctionDef, set[str]]: - fn_body = [] - used_keys = set() + fn_body: list[c_ast.Expr] = [] + used_keys: set[str] = set() unused_keys = self.pm.data_store.unused_keys cached_data_keys = self.pm.data_store.cached_data.keys() @@ -218,7 +217,7 @@ def generate_evaluate(self) -> tuple[c_ast.FunctionDef, set[str]]: used_keys.add(output_key) used_keys |= set(inputs) - arguments = [] + arguments: list[c_ast.Parameter] = [] for used_key in sorted(used_keys): arguments.append(c_ast.Parameter("Array *", used_key)) @@ -226,9 +225,9 @@ def generate_evaluate(self) -> tuple[c_ast.FunctionDef, set[str]]: return evaluate_fn, used_keys - def generate_evaluate_gradients(self): - fn_body = [] - used_keys = set() + def generate_evaluate_gradients(self) -> tuple[c_ast.FunctionDef, set[str]]: + fn_body: list[c_ast.Expr] = [] + used_keys: set[str] = set() all_ignored_keys = ( self.pm.ignore_grad_keys @@ -250,7 +249,7 @@ def generate_evaluate_gradients(self): # Assume all inputs are Array grad_inputs = [input_key + "_grad" for input_key in inputs] for idx in range(len(grad_inputs)): - fn_inputs = ( + fn_inputs: list[str] = ( [output_key + "_grad", c_ast.Constant(idx), output_key] + inputs + grad_inputs @@ -267,7 +266,7 @@ def generate_evaluate_gradients(self): used_keys |= set(inputs) used_keys |= set(grad_inputs) - arguments = [] + arguments: list[c_ast.Parameter] = [] for used_key in sorted(used_keys): arguments.append(c_ast.Parameter("Array *", used_key)) diff --git a/mithril/framework/codegen/code_gen.py b/mithril/framework/codegen/code_gen.py index 90fbf686..6281ce8f 100644 --- a/mithril/framework/codegen/code_gen.py +++ b/mithril/framework/codegen/code_gen.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Generic +from typing import Any, Generic from mithril import DataType @@ -28,9 +28,11 @@ def __init__(self, pm: PhysicalModel[DataType]) -> None: self.file_path: str | None = None @abstractmethod - def generate_code(self, file_path: str | None = None): + def generate_code(self, file_path: str | None = None) -> None: raise NotImplementedError("generate_code is not implemented") @abstractmethod - def compile_code(self, jit: bool) -> tuple[Callable, Callable, Callable]: + def compile_code( + self, jit: bool + ) -> tuple[Callable[..., Any], Callable[..., Any], Callable[..., Any]]: raise NotImplementedError("compile_code is not implemented") diff --git a/mithril/framework/codegen/numpy_gen.py b/mithril/framework/codegen/numpy_gen.py index 818886c6..80e6ddea 100644 --- a/mithril/framework/codegen/numpy_gen.py +++ b/mithril/framework/codegen/numpy_gen.py @@ -16,9 +16,12 @@ import keyword from collections.abc import Callable from functools import partial +from typing import Any import numpy as np +from mithril import DataType + from ...backends.with_manualgrad.numpy_backend import NumpyBackend from ...core import Dtype from ...framework.physical.model import PhysicalModel @@ -39,17 +42,17 @@ from .utils import check_repr_inequality -class NumpyCodeGen(PythonCodeGen): +class NumpyCodeGen(PythonCodeGen[DataType]): BACKWARD_FN_SUFFIX = "_grad" - def __init__(self, pm: PhysicalModel) -> None: + def __init__(self, pm: PhysicalModel[DataType]) -> None: super().__init__(pm) assert isinstance(self.pm.backend, NumpyBackend) self.backend: NumpyBackend = self.pm.backend def generate_functions(self): - functions = [] + functions: list[ast.FunctionDef] = [] functions.append(self.generate_evaluate()) if not self.pm.inference: functions.append(self.generate_evaluate_gradients(self.pm.ignore_grad_keys)) @@ -87,19 +90,25 @@ def generate_imports(self): return imports - def compile_code(self, jit: bool = False): + def compile_code( + self, jit: bool = False + ) -> tuple[Callable[..., Any], Callable[..., Any], Callable[..., Any]]: eval_fn, grad_fn = self.exec_generated_code() # TODO: Not looks good, and looks over complicated! def evaluate_gradients_wrapper_manualgrad( - params: dict[str, np.ndarray], - data: dict[str, np.ndarray | ValueType] | None = None, - output_gradients: dict[str, np.ndarray] | None = None, + params: dict[str, np.ndarray[Any, Any]], + data: dict[str, np.ndarray[Any, Any] | ValueType] | None = None, + output_gradients: dict[str, np.ndarray[Any, Any]] | None = None, *, - grad_fn: Callable, + grad_fn: Callable[..., Any], include_output: bool = False, ) -> ( - dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]] + dict[str, np.ndarray[Any, Any]] + | tuple[ + dict[str, np.ndarray[Any, Any]], + dict[str, np.ndarray[Any, Any] | dict[str, np.ndarray[Any, Any]]], + ] ): # TODO: Consider not unioning batch data (data) into self.data # If evaluate_gradients called directly, first call evaluate. @@ -107,11 +116,11 @@ def evaluate_gradients_wrapper_manualgrad( if data is None: data = {} - output: dict[str, np.ndarray] = eval_fn( + output: dict[str, np.ndarray[Any, Any]] = eval_fn( params=params, data=data, cache=cached_data ) # Initialize gradients as zero with corresponding shapes. - gradients: dict[str, np.ndarray] = {} + gradients: dict[str, np.ndarray[Any, Any]] = {} for key in ( self.pm._flat_graph.all_keys - self.pm.data_store.all_static_keys @@ -377,7 +386,7 @@ def generate_evaluate_gradients( primitive_global_inputs += [ key for key in kwargs.values() if "cache" not in key ] + [local_to_global_dict["cache"]] - primitive_local_inputs = [ + primitive_local_inputs: list[str] = [ global_to_local_dict[key].pop(0) for key in primitive_global_inputs ] diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index 0e319088..e85a0fd8 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -37,7 +37,7 @@ class PythonCodeGen[DataType](CodeGen): - def __init__(self, pm: PhysicalModel) -> None: + def __init__(self, pm: PhysicalModel[Any]) -> None: super().__init__(pm) self.module = ast.parse("") @@ -77,7 +77,9 @@ def compile_code(self, jit: bool = False): eval_fn, grad_fn = self.exec_generated_code() return self.post_process_fns(eval_fn, grad_fn, jit) - def exec_generated_code(self) -> tuple[Callable, Callable | None]: + def exec_generated_code( + self, + ) -> tuple[Callable[..., Any], Callable[..., Any] | None]: if self.code is None: raise Exception( "Code is not generated yet! Please call generate_code() first." diff --git a/mithril/framework/codegen/utils.py b/mithril/framework/codegen/utils.py index e0b2370d..1975aeeb 100644 --- a/mithril/framework/codegen/utils.py +++ b/mithril/framework/codegen/utils.py @@ -14,6 +14,7 @@ import ast import keyword +from typing import Any from ...backends.backend import Backend from ..common import ShapeNode @@ -22,7 +23,7 @@ # TODO: This name misleads -def partial_array_creation_func(backend: Backend, formula_key: str) -> ast.stmt: +def partial_array_creation_func(backend: Backend[Any], formula_key: str) -> ast.stmt: kwargs = [ast.keyword(arg="precision", value=ast.Constant(value=backend.precision))] # We don't need device in manulgrad(Numpy) diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index 14150e12..4d805955 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -3573,7 +3573,7 @@ def tuple_converter_constraint(output: Scalar, input: Scalar) -> ConstrainResult def cross_entropy_constraint( - categorical: Scalar, input: Tensor, target: Tensor + categorical: Scalar, input: Tensor[Any], target: Tensor[Any] ) -> ConstrainResultType: assert input._temp_shape is not None, "Input shape of reverse is not set!" assert target._temp_shape is not None, "Target shape of reverse is not set!" diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 17f25fdb..00c249a1 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -35,6 +35,7 @@ MainValueType, NotAvailable, Scalar, + ShapeNode, ShapesType, ShapeTemplateType, ShapeType, @@ -228,14 +229,14 @@ def _set_shapes( **kwargs: ShapeTemplateType, ) -> None: # Initialize assigned shapes dictionary to store assigned shapes. - assigned_shapes = {} + assigned_shapes: dict[str, ShapeTemplateType] = {} if updates is None: updates = Updates() model = self._get_outermost_parent() used_keys: dict[str | int, ShapeType] = {} - shape_nodes = {} + shape_nodes: dict[str | Connection, tuple[ShapeNode, str]] = {} # TODO: Can this be refactored to use a single loop? for key, shape in chain(shapes.items(), kwargs.items()): metadata = self.conns.extract_metadata(key) diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 992ae49d..9bd7b933 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -638,11 +638,7 @@ def handle_auto_conversion( # Return NOTGIVEN if IOKey has value, name and expose # attributes as their default values. existing_conn = None - if ( - connection._name is None - and connection._value == NOT_GIVEN - and connection._expose is None - ): + if connection._name is None and connection._value == NOT_GIVEN: return NOT_GIVEN set_type = connection._type @@ -1436,8 +1432,8 @@ def _generate_keys( return key_mappings def get_unique_submodel_names(self) -> dict[BaseModel, str]: - name_mapping = {} - existing_names = set() + name_mapping: dict[BaseModel, str] = {} + existing_names: set[str] = set() model_type_dict: dict[str, list[BaseModel]] = {} # First, assign existing names and track used names. diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 9c6ca319..7e05e561 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -672,41 +672,49 @@ def _pre_compile( raise ValueError("All outputs gradient are ignored.") def generate_functions( - self, eval_fn: Callable, grad_fn: Callable, eval_all_fn: Callable + self, + eval_fn: Callable[ + [dict[str, DataType] | None, Mapping[str, MainValueType | DataType] | None], + Mapping[str, MainValueType | DataType], + ], + grad_fn: Callable[ + [ + dict[str, DataType] | None, + Mapping[str, MainValueType | DataType] | None, + dict[str, DataType] | None, + ], + dict[str, DataType], + ], + eval_all_fn: Callable[ + [ + dict[str, DataType] | None, + Mapping[str, MainValueType | DataType] | None, + dict[str, DataType] | None, + ], + tuple[Mapping[str, MainValueType | DataType], dict[str, DataType]], + ], ) -> None: - """This function compiles Physical Model. Compilation - process is as follows: - 1. Infer ignore keys using infer_ignore_keys function. - 2. Infer shapes using infer_shapes function. - 3. Infer static keys using infer_static_keys function. - 4. Infer ignore_grad keys using infer_ignore_grad_keys - function. Note that this function is only required - for numpy backend. - 5. Generate and jit evaluate function using ast. - 6. Generate and jit evaluate_gradients function using - ast for numpy backend and using auto-grad - functionality for Jax and Torch. - - Parameters - ---------- - shapes : Optional[IOShapeType], optional - _description_, by default None - static_keys : dict[str, dataType] | None, optional - _description_, by default None - ignore_grad_keys : set[str] | None, optional - _description_, by default None - ignore_keys : set[str] | None, optional - _description_, by default None - - Returns - ------- - tuple[Callable, Callable] - _description_ - """ - - self._generated_eval_fn = eval_fn - self._generated_compute_gradients_fn = grad_fn - self._generated_evaluate_all_fn = eval_all_fn + self._generated_eval_fn: Callable[ + [dict[str, DataType] | None, Mapping[str, MainValueType | DataType] | None], + Mapping[str, MainValueType | DataType], + ] = eval_fn + self._generated_compute_gradients_fn: Callable[ + [ + dict[str, DataType] | None, + Mapping[str, MainValueType | DataType] | None, + dict[str, DataType] | None, + ], + dict[str, DataType], + ] = grad_fn + + self._generated_evaluate_all_fn: Callable[ + [ + dict[str, DataType] | None, + Mapping[str, MainValueType | DataType] | None, + dict[str, DataType] | None, + ], + tuple[Mapping[str, MainValueType | DataType], dict[str, DataType]], + ] = eval_all_fn def create_jacobian_fn(self, generated_fn: Callable): # TODO: Fix this method to make it picklable! @@ -1092,7 +1100,7 @@ def extract_connection_info( ): if name_mappings is None: name_mappings = define_unique_names(self._flat_graph.get_models()) - conn_info: dict[str, tuple[dict, dict]] = {} + conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {} for model, model_name in name_mappings.items(): conn_info.setdefault(model_name, ({}, {})) @@ -1198,7 +1206,7 @@ def evaluate( self, params: dict[str, DataType] | None = None, data: Mapping[str, DataType | MainValueType] | None = None, - ): + ) -> Mapping[str, MainValueType | DataType]: if ( isinstance(self.backend, ParallelBackend) and self.backend._parallel_manager is not None @@ -1212,7 +1220,7 @@ def evaluate_gradients( params: dict[str, DataType] | None = None, data: Mapping[str, DataType | MainValueType] | None = None, output_gradients: dict[str, DataType] | None = None, - ): + ) -> dict[str, DataType]: if self.inference: raise NotImplementedError( "Inference mode does not support gradients calculation" @@ -1232,7 +1240,7 @@ def evaluate_all( params: dict[str, DataType] | None = None, data: Mapping[str, DataType | MainValueType] | None = None, output_gradients: dict[str, DataType] | None = None, - ): + ) -> tuple[Mapping[str, MainValueType | DataType], dict[str, DataType]]: if self.inference: raise NotImplementedError( "Inferece mode does not support gradients calculation" diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 3994b0b1..76470948 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -393,7 +393,7 @@ def find_intersection_type( return None -def find_type(connection) -> type: +def find_type(connection: Any) -> type[Any]: if isinstance(connection, tuple | list): element_types: list[Any] = [find_type(elem) for elem in connection] if isinstance(connection, tuple): @@ -405,8 +405,8 @@ def find_type(connection) -> type: return type(connection) -def is_union(typ): - if hasattr(typ, "__origin__"): +def is_union(typ: type | UnionType | GenericAlias | NestedListType) -> bool: + if isinstance(typ, GenericAlias): if ... in typ.__args__: return True return any(is_union(subtype) for subtype in typ.__args__) diff --git a/mithril/utils/func_utils.py b/mithril/utils/func_utils.py index 7a867222..4b275e46 100644 --- a/mithril/utils/func_utils.py +++ b/mithril/utils/func_utils.py @@ -23,7 +23,7 @@ def prepare_function_args( data: dict[str, Tensor[Any] | Scalar], - function: Callable, + function: Callable[..., Any], inputs: key_map_type, array_creation_funcs: list[str], reduce_with_defaults: bool = True, diff --git a/mithril/utils/utils.py b/mithril/utils/utils.py index fe3fe0a9..51281c79 100755 --- a/mithril/utils/utils.py +++ b/mithril/utils/utils.py @@ -366,7 +366,7 @@ def find_slot_lengths(stacked_data, lengths): def unpack_time_slot_data( - backend: Backend, + backend: Backend[DataType], data: dict[str, DataType], max_length: int, max_size: int, diff --git a/tests/json_files/randomized_model_tests_all_backends.json b/tests/json_files/randomized_model_tests_all_backends.json index 34ed6750..d31db327 100644 --- a/tests/json_files/randomized_model_tests_all_backends.json +++ b/tests/json_files/randomized_model_tests_all_backends.json @@ -1551,7 +1551,7 @@ "name": "OneToMany", "differentiability_info": {"input": true, "initial_hidden": true}, "regular_args": { - "max_sequence_length": 5, + "max_sequence_length": 6, "cell_type": "RNNCell" } }, @@ -1573,6 +1573,10 @@ }, "target4": { "shapes": [[4, 4], [1,1], [15 , 15]] + }, + + "target5": { + "shapes": [[2, 2], [1,1], [15 , 15]] } }, "iterations": 20 diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index a562fe24..284eb0c1 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -122,7 +122,7 @@ def evaluate_case( numeric_shape_dict = ( {key: value.shape for key, value in inputs.items()} | {key: value.shape for key, value in model_grad.items()} - | {key: value.shape for key, value in outputs.items()} + | {key: value.shape for key, value in outputs.items()} # type: ignore | {key: value.shape for key, value in static_keys.items()} ) if reference_shapes is not None: @@ -155,7 +155,7 @@ def evaluate_case( backend.abs(v - out) < backend.abs(v) * relative_tolerance ) ) - ) and (out.shape == (() if isinstance(v, float) else v.shape)) + ) and (out.shape == (() if isinstance(v, float) else v.shape)) # type: ignore else: raise Exception( f"Output is supposed to return value for the {k} key, but " @@ -260,7 +260,7 @@ def assert_evaluations_equal(model1, model2, backend, static_keys): output_recreated = pm_recreated.evaluate(inputs) assert list(output_base.keys()) == list(output_recreated.keys()) for key in output_base: - assert backend.abs(output_base[key] - output_recreated[key]).all() < 1e-14 + assert backend.abs(output_base[key] - output_recreated[key]).all() < 1e-14 # type: ignore class TensorMock: diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index 63226965..18451247 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -175,7 +175,7 @@ def compile_and_compare( < backend.abs(v) * relative_tolerance ) ) - ) and (out.shape == (() if isinstance(v, float) else v.shape)) + ) and (out.shape == (() if isinstance(v, float) else v.shape)) # type: ignore else: if not isinstance(eq := (out == v), bool): eq = eq.all() @@ -2376,7 +2376,7 @@ def test_cast_int16(): model = Cast(dtype=mithril.int16) inp_int = np.array([1, -2, 3], dtype=np.int32) inp_float = np.array([1, -2, 3], dtype=np.float32) - backends: list[Backend] = [ + backends: list[TorchBackend | JaxBackend | NumpyBackend | MlxBackend] = [ TorchBackend(precision=16), TorchBackend(precision=32), TorchBackend(precision=64), @@ -2404,16 +2404,19 @@ def test_cast_int16(): for backend in backends: for static in statics.values(): - static = backend.array(static) + assert isinstance(static, np.ndarray) + backend_static = backend.array(static) pm = mithril.compile( model, - backend, - constant_keys={"input": static}, + backend, # type: ignore + constant_keys={"input": backend_static}, inference=True, ) res = pm.evaluate() - assert res["output"].dtype == expected_dtypes[backend.type] - np.testing.assert_allclose(res["output"], reference_outputs["output"]) + res_out = res["output"] + assert isinstance(res_out, backend.DataType) + assert res_out.dtype == expected_dtypes[backend.type] + np.testing.assert_allclose(res_out, reference_outputs["output"]) def test_cast_int32(): @@ -2456,8 +2459,10 @@ def test_cast_int32(): inference=True, ) res = pm.evaluate() - assert res["output"].dtype == expected_dtypes[backend.type] - np.testing.assert_allclose(res["output"], reference_outputs["output"]) + res_out = res["output"] + assert isinstance(res_out, backend.DataType) # type: ignore + assert res_out.dtype == expected_dtypes[backend.type] + np.testing.assert_allclose(res_out, reference_outputs["output"]) def test_cast_int64(): @@ -2500,15 +2505,15 @@ def test_cast_int64(): inference=True, ) res = pm.evaluate() - assert res["output"].dtype == expected_dtypes[backend.type] - np.testing.assert_allclose(res["output"], reference_outputs["output"]) + assert res["output"].dtype == expected_dtypes[backend.type] # type: ignore + np.testing.assert_allclose(res["output"], reference_outputs["output"]) # type: ignore def test_cast_float16(): model = Cast(dtype=mithril.float16) inp_int = np.array([1, -2, 3], dtype=np.int32) inp_float = np.array([1, -2, 3], dtype=np.float32) - backends: list[Backend] = [ + backends: list[TorchBackend | JaxBackend | NumpyBackend | MlxBackend] = [ TorchBackend(precision=16), TorchBackend(precision=32), TorchBackend(precision=64), @@ -2536,16 +2541,17 @@ def test_cast_float16(): for backend in backends: for static in statics.values(): - static = backend.array(static) + _static = backend.array(static) pm = mithril.compile( model, - backend, - constant_keys={"input": static}, + backend, # type: ignore + constant_keys={"input": _static}, inference=True, ) - res = pm.evaluate() - assert res["output"].dtype == expected_dtypes[backend.type] - np.testing.assert_allclose(res["output"], reference_outputs["output"]) + res = pm.evaluate()["output"] + assert isinstance(res, backend.DataType) + assert res.dtype == expected_dtypes[backend.type] + np.testing.assert_allclose(res, reference_outputs["output"]) def test_cast_float32(): @@ -2588,8 +2594,10 @@ def test_cast_float32(): inference=True, ) res = pm.evaluate() - assert res["output"].dtype == expected_dtypes[backend.type] - np.testing.assert_allclose(res["output"], reference_outputs["output"]) + res_out = res["output"] + assert isinstance(res_out, backend.DataType) # type: ignore + assert res_out.dtype == expected_dtypes[backend.type] + np.testing.assert_allclose(res_out, reference_outputs["output"]) def test_cast_float64(): @@ -2628,8 +2636,10 @@ def test_cast_float64(): inference=True, ) res = pm.evaluate() - assert res["output"].dtype == expected_dtypes[backend.type] - np.testing.assert_allclose(res["output"], reference_outputs["output"]) + res_out = res["output"] + assert isinstance(res_out, backend.DataType) # type: ignore + assert res_out.dtype == expected_dtypes[backend.type] + np.testing.assert_allclose(res_out, reference_outputs["output"]) def test_cast_bool(): @@ -2672,8 +2682,10 @@ def test_cast_bool(): inference=True, ) res = pm.evaluate() - assert res["output"].dtype == expected_dtypes[backend.type] - np.testing.assert_allclose(res["output"], reference_outputs["output"]) + res_out = res["output"] + assert isinstance(res_out, backend.DataType) # type: ignore + assert res_out.dtype == expected_dtypes[backend.type] + np.testing.assert_allclose(res_out, reference_outputs["output"]) def test_dtype_int16(): diff --git a/tests/scripts/test_c_backend.py b/tests/scripts/test_c_backend.py index 7e60ca0e..0b234b88 100644 --- a/tests/scripts/test_c_backend.py +++ b/tests/scripts/test_c_backend.py @@ -18,6 +18,7 @@ import numpy as np from mithril import CBackend, NumpyBackend, compile +from mithril.backends.with_manualgrad.c_backend.src.array import PyArray from mithril.models import Add, IOKey, Model, Multiply from ..utils import with_temp_file @@ -65,7 +66,11 @@ def test_cbackend_1(): ) for key in np_outputs: - assert np.allclose(c_backend.to_numpy(c_outputs[key]), np_outputs[key]) + out = c_outputs[key] + out_np = np_outputs[key] + assert isinstance(out_np, np.ndarray) + assert isinstance(out, PyArray) + assert np.allclose(c_backend.to_numpy(out), out_np) for key in np_grads: assert np.allclose(c_backend.to_numpy(c_grads[key]), np_grads[key]) @@ -124,7 +129,11 @@ def test_cbackend_2(file_path: str): ) for key in np_outputs: - assert np.allclose(c_backend.to_numpy(c_outputs[key]), np_outputs[key]) + out = c_outputs[key] + out_np = np_outputs[key] + assert isinstance(out, np.ndarray) + assert isinstance(out, PyArray) + assert np.allclose(c_backend.to_numpy(out), out_np) for key in np_grads: assert np.allclose(c_backend.to_numpy(c_grads[key]), np_grads[key]) @@ -184,7 +193,11 @@ def test_cbackend_3(): ) for key in np_outputs: - assert np.allclose(c_backend.to_numpy(c_outputs[key]), np_outputs[key]) + c_out = c_outputs[key] + np_out = np_outputs[key] + assert isinstance(c_out, PyArray) + assert isinstance(np_out, np.ndarray) + assert np.allclose(c_backend.to_numpy(c_out), np_out) for key in np_grads: assert np.allclose(c_backend.to_numpy(c_grads[key]), np_grads[key]) @@ -215,11 +228,11 @@ def test_broadcast_1(): c_right = c_backend.array(right) c_outputs = c_pm.evaluate({"left": c_left, "right": c_right, "mul": c_mul}) + out = c_outputs["output"] + assert isinstance(out, PyArray) - assert c_outputs["output"].shape == (5, 5) - np.testing.assert_allclose( - c_backend.to_numpy(c_outputs["output"]), (left + right) * mul - ) + assert out.shape == (5, 5) + np.testing.assert_allclose(c_backend.to_numpy(out), (left + right) * mul) def test_broadcast_2(): @@ -247,8 +260,8 @@ def test_broadcast_2(): c_right = c_backend.array(right) c_outputs = c_pm.evaluate({"left": c_left, "right": c_right, "mul": c_mul}) + out = c_outputs["output"] + assert isinstance(out, PyArray) - assert c_outputs["output"].shape == (5, 5) - np.testing.assert_allclose( - c_backend.to_numpy(c_outputs["output"]), (left + right) * mul - ) + assert out.shape == (5, 5) + np.testing.assert_allclose(c_backend.to_numpy(out), (left + right) * mul) diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index 021a825f..bde2f773 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -136,7 +136,7 @@ def assert_all_backends_device_precision(model: Model): assert get_array_precision(randomized_input, _type) == precision outputs = comp_model.evaluate(randomized_inputs) - initial_outputs = outputs.copy() + initial_outputs = outputs.copy() # type: ignore # Check if outputs have correct device and precision for output in outputs.values(): @@ -144,7 +144,8 @@ def assert_all_backends_device_precision(model: Model): assert get_array_precision(output, _type) == precision grads = comp_model.evaluate_gradients( - output_gradients=outputs, params=randomized_inputs + output_gradients=outputs, # type: ignore + params=randomized_inputs, ) # Check if gradients have correct device and precision @@ -305,11 +306,13 @@ def test_default_given_compile_numpy(): data = {"axis": None} result = compiled_model.evaluate(inputs, data) - output_gradients = {"output": np.ones_like(result["output"])} + out = result["output"] + assert isinstance(out, np.ndarray) + output_gradients = {"output": np.ones_like(out)} compiled_model.evaluate_gradients( params=inputs, data=data, output_gradients=output_gradients ) - np.testing.assert_array_equal(expected_result, result["output"]) + np.testing.assert_array_equal(expected_result, out) def test_default_given_extend_numpy_3(): @@ -334,11 +337,13 @@ def test_default_given_extend_numpy_3(): data = {"input": np_input} result = compiled_model.evaluate(inputs, data) - output_gradients = {"output": np.ones_like(result["output"])} + out = result["output"] + assert isinstance(out, np.ndarray) + output_gradients = {"output": np.ones_like(out)} compiled_model.evaluate_gradients( params=inputs, data=data, output_gradients=output_gradients ) - np.testing.assert_array_equal(expected_result, result["output"]) + np.testing.assert_array_equal(expected_result, out) def test_default_given_extend_numpy_3_set_values(): @@ -363,11 +368,13 @@ def test_default_given_extend_numpy_3_set_values(): data = {"input": np_input} result = compiled_model.evaluate(inputs, data) - output_gradients = {"output": np.ones_like(result["output"])} + out = result["output"] + assert isinstance(out, np.ndarray) + output_gradients = {"output": np.ones_like(out)} compiled_model.evaluate_gradients( params=inputs, data=data, output_gradients=output_gradients ) - np.testing.assert_array_equal(expected_result, result["output"]) + np.testing.assert_array_equal(expected_result, out) def test_constant_given_data_numpy(): @@ -391,11 +398,13 @@ def test_constant_given_data_numpy(): data = {"axis": 0} result = compiled_model.evaluate(inputs, data) - output_gradients = {"output": np.ones_like(result["output"])} + out = result["output"] + assert isinstance(out, np.ndarray) + output_gradients = {"output": np.ones_like(out)} compiled_model.evaluate_gradients( params=inputs, data=data, output_gradients=output_gradients ) - np.testing.assert_array_equal(expected_result, result["output"]) + np.testing.assert_array_equal(expected_result, out) def test_constant_numpy(): @@ -1051,6 +1060,7 @@ def test_bool_tensor_numpy_32(): model += add_1(left=[7.0, 8.0], right=not_1.output, output=IOKey(name="output")) comp_model = mithril.compile(model=model, backend=NumpyBackend()) output = comp_model.evaluate()["output"] + assert isinstance(output, np.ndarray) np.testing.assert_allclose(output, ref) assert output.dtype == np.float32 @@ -1065,6 +1075,7 @@ def test_bool_tensor_numpy_32_set_values(): model.set_values({model.input: [False, False]}) # type: ignore comp_model = mithril.compile(model=model, backend=NumpyBackend()) output = comp_model.evaluate()["output"] + assert isinstance(output, np.ndarray) np.testing.assert_allclose(output, ref) assert output.dtype == np.float32 @@ -1078,6 +1089,7 @@ def test_bool_tensor_numpy_64(): model += add_1(left=[7.0, 8.0], right=not_1.output, output=IOKey(name="output")) comp_model = mithril.compile(model=model, backend=NumpyBackend(precision=64)) output = comp_model.evaluate()["output"] + assert isinstance(output, np.ndarray) np.testing.assert_allclose(output, ref) assert output.dtype == np.float64 @@ -1090,9 +1102,11 @@ def test_bool_tensor_torch_32(): model += not_1(input=IOKey(value=[False, False], name="input")) model += add_1(left=[7.0, 8.0], right=not_1.output, output=IOKey(name="output")) comp_model = mithril.compile(model=model, backend=TorchBackend(precision=32)) - output = comp_model.evaluate()["output"].numpy() - np.testing.assert_allclose(output, ref) - assert output.dtype == np.float32 + output = comp_model.evaluate()["output"] + assert isinstance(output, torch.Tensor) + out = output.numpy() + np.testing.assert_allclose(out, ref) + assert out.dtype == np.float32 def test_bool_tensor_torch_64(): @@ -1103,9 +1117,11 @@ def test_bool_tensor_torch_64(): model += not_1(input=IOKey(value=[False, False], name="input")) model += add_1(left=[7.0, 8.0], right=not_1.output, output=IOKey(name="output")) comp_model = mithril.compile(model=model, backend=TorchBackend(precision=64)) - output = comp_model.evaluate()["output"].numpy() - np.testing.assert_allclose(output, ref) - assert output.dtype == np.float64 + output = comp_model.evaluate()["output"] + assert isinstance(output, torch.Tensor) + out = output.numpy() + np.testing.assert_allclose(out, ref) + assert out.dtype == np.float64 def test_bool_tensor_jax_32(): @@ -1177,6 +1193,7 @@ def test_static_input_1(): "right": np.array(3.0, dtype=np.float32), } )["output"] + assert isinstance(output, np.ndarray) np.testing.assert_allclose(output, ref) assert output.dtype == np.float32 @@ -1214,6 +1231,7 @@ def test_static_input_2(): ) output = comp_model.evaluate()["output"] + assert isinstance(output, np.ndarray) np.testing.assert_allclose(output, ref) assert output.dtype == np.float32 @@ -1251,6 +1269,7 @@ def test_static_input_3(): ) output = comp_model.evaluate()["output"] + assert isinstance(output, np.ndarray) np.testing.assert_allclose(output, ref) assert output.dtype == np.float32 @@ -1271,6 +1290,7 @@ def test_static_input_4(): "in2": np.array(3.0, dtype=np.float32), } )["output"] + assert isinstance(output, np.ndarray) np.testing.assert_allclose(output, ref) assert output.dtype == np.float32 @@ -1293,6 +1313,7 @@ def test_static_input_5(): ) output = comp_model.evaluate()["output"] + assert isinstance(output, np.ndarray) np.testing.assert_allclose(output, ref) assert output.dtype == np.float32 @@ -2476,7 +2497,9 @@ def test_maxpool_1d_padding_type_input(): out_1 = pm.evaluate( data={"input": backend.array([[[10.0, 11.0, 12.0, 13.0, 14.0]]])} ) - assert (out_1["output"] == backend.array([[[11.0, 13.0]]])).all() + out = out_1["output"] + assert isinstance(out, torch.Tensor) + assert (out == backend.array([[[11.0, 13.0]]])).all() def test_maxpool_1d_padding_input_in_evaluate(): @@ -2494,7 +2517,9 @@ def test_maxpool_1d_padding_input_in_evaluate(): "padding": PaddingType.VALID, } ) - assert (out_1["output"] == backend.array([[[11.0, 13.0]]])).all() + out = out_1["output"] + assert isinstance(out, torch.Tensor) + assert (out == backend.array([[[11.0, 13.0]]])).all() def test_maxpool_1d_padding_input_solved_in_constraint(): diff --git a/tests/scripts/test_extend_template.py b/tests/scripts/test_extend_template.py index cd553089..c5edf5e2 100644 --- a/tests/scripts/test_extend_template.py +++ b/tests/scripts/test_extend_template.py @@ -501,9 +501,9 @@ def test_div(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([0.5, -1, 1.5, 0, -2.5, 3]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([0.5, -1, 1.5, 0, -2.5, 3]), out, 1e-6) def test_rdiv(): @@ -522,9 +522,9 @@ def test_rdiv(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([2, -1, 2 / 3, 2, -0.4, 1 / 3]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([2, -1, 2 / 3, 2, -0.4, 1 / 3]), out, 1e-6) def test_floor_div(): @@ -543,9 +543,9 @@ def test_floor_div(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([0.0, -1, 1.0, 0, -3.0, 3]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([0.0, -1, 1.0, 0, -3.0, 3]), out, 1e-6) def test_rfloor_div(): @@ -562,11 +562,10 @@ def test_rfloor_div(): model2 += (div := FloorDivide())(numerator=2, denominator="input") model2 += Buffer()(input=div.output, output=IOKey(name="output")) compare_models(model1, model2, backend, data) - pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([2.0, -1, 0, 2, -1, 0]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([2.0, -1, 0, 2, -1, 0]), out, 1e-6) def test_pow(): @@ -585,9 +584,9 @@ def test_pow(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([1, 4, 9, 0, 25, 36]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([1, 4, 9, 0, 25, 36]), out, 1e-6) def test_rpow(): @@ -606,9 +605,9 @@ def test_rpow(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([2, 1 / 4, 8, 1, 1 / 32, 64]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([2, 1 / 4, 8, 1, 1 / 32, 64]), out, 1e-6) def test_absolute(): @@ -627,7 +626,9 @@ def test_absolute(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - assert (backend.array([1.0, 2, 3, 0, 5, 6]) == pm.evaluate()["output"]).all() + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + assert (backend.array([1.0, 2, 3, 0, 5, 6]) == out).all() def test_mean(): @@ -646,7 +647,9 @@ def test_mean(): compare_models(model1, model2, backend, data, check_internals=False) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose(backend.array(1 / 2), pm.evaluate()["output"], 1e-6) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array(1 / 2), out, 1e-6) def test_max(): @@ -665,7 +668,9 @@ def test_max(): compare_models(model1, model2, backend, data, check_internals=False) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose(backend.array(6), pm.evaluate()["output"], 1e-6) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array(6), out, 1e-6) def test_sum(): @@ -684,7 +689,9 @@ def test_sum(): compare_models(model1, model2, backend, data, check_internals=False) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose(backend.array(3.0), pm.evaluate()["output"], 1e-6) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array(3.0), out, 1e-6) def test_min(): @@ -703,7 +710,9 @@ def test_min(): compare_models(model1, model2, backend, data, check_internals=False) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose(backend.array(-5), pm.evaluate()["output"], 1e-6) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array(-5), out, 1e-6) def test_prod(): @@ -722,7 +731,9 @@ def test_prod(): compare_models(model1, model2, backend, data, check_internals=False) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose(backend.array(90), pm.evaluate()["output"], 1e-6) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array(90), out, 1e-6) def test_variance(): @@ -741,9 +752,9 @@ def test_variance(): compare_models(model1, model2, backend, data, check_internals=False) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array(12.201388888888888), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array(12.201388888888888), out, 1e-6) def test_greater_than(): @@ -769,9 +780,11 @@ def test_greater_than(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([False, False, True, False, True, False]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -799,9 +812,11 @@ def test_greater_equal(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([False, True, True, False, True, True]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -829,9 +844,11 @@ def test_less_than(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([True, False, False, True, False, False]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -859,9 +876,11 @@ def test_less_equal(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([True, True, False, True, False, True]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -889,9 +908,11 @@ def test_equal(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([False, True, False, False, False, True]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -919,9 +940,11 @@ def test_not_equal(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([True, False, True, True, True, False]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -950,9 +973,11 @@ def test_not(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([False, True, False, False, False, True]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -982,9 +1007,11 @@ def test_and(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([False, False, False, True, False, True]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -1014,9 +1041,11 @@ def test_or(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([True, False, True, True, False, True]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -1046,9 +1075,11 @@ def test_xor(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([True, False, True, False, False, False]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -1079,9 +1110,11 @@ def test_xor2(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) np.testing.assert_allclose( backend.array([True, True, True, True, False, True]), - pm.evaluate()["output"], + out, 1e-6, ) @@ -1109,9 +1142,9 @@ def test_lshift_1(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) - np.testing.assert_allclose( - backend.array([2, -4, 12, 40, -10, 12]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([2, -4, 12, 40, -10, 12]), out, 1e-6) def test_lshift_2(): @@ -1134,9 +1167,9 @@ def test_lshift_2(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) - np.testing.assert_allclose( - backend.array([4, -8, 12, 20, -20, 24]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([4, -8, 12, 20, -20, 24]), out, 1e-6) def test_lshift_3(): @@ -1157,9 +1190,9 @@ def test_lshift_3(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) - np.testing.assert_allclose( - backend.array([4, 0, 16, 64, 0, 128]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([4, 0, 16, 64, 0, 128]), out, 1e-6) def test_rshift_1(): @@ -1185,9 +1218,9 @@ def test_rshift_1(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) - np.testing.assert_allclose( - backend.array([0, -1, 0, 0, -3, 3]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([0, -1, 0, 0, -3, 3]), out, 1e-6) def test_rshift_2(): @@ -1210,9 +1243,9 @@ def test_rshift_2(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) - np.testing.assert_allclose( - backend.array([0, -1, 0, 1, -2, 1]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([0, -1, 0, 1, -2, 1]), out, 1e-6) def test_rshift_3(): @@ -1233,9 +1266,9 @@ def test_rshift_3(): pm = mithril.compile( model=model1, backend=backend, constant_keys=data, inference=True ) - np.testing.assert_allclose( - backend.array([1, 0, 0, 0, 0, 2]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([1, 0, 0, 0, 0, 2]), out, 1e-6) def test_minus(): @@ -1256,9 +1289,9 @@ def test_minus(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([-1.0, 2, -3, -0.5, 5, -6]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([-1.0, 2, -3, -0.5, 5, -6]), out, 1e-6) def test_use_submodel_conn_1(): @@ -1287,9 +1320,9 @@ def test_use_submodel_conn_1(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([5.0, 3.5, 6, 4.75, 2, 7.5]), pm.evaluate()["output"], 1e-6 - ) + out = pm.evaluate()["output"] + assert isinstance(out, jnp.ndarray) + np.testing.assert_allclose(backend.array([5.0, 3.5, 6, 4.75, 2, 7.5]), out, 1e-6) def test_use_multiple_times(): @@ -1312,12 +1345,13 @@ def test_use_multiple_times(): compare_models(model1, model2, backend, data) pm = mithril.compile(model=model1, backend=backend, constant_keys=data) - np.testing.assert_allclose( - backend.array([2, 0.5, 3, 1.75, -1, 4.5]), pm.evaluate()["output1"], 1e-6 - ) - np.testing.assert_allclose( - backend.array([2, 0.5, 3, 1.75, 0, 4.5]), pm.evaluate()["output2"], 1e-6 - ) + out1 = pm.evaluate()["output1"] + out2 = pm.evaluate()["output2"] + assert isinstance(out1, jnp.ndarray) + assert isinstance(out2, jnp.ndarray) + + np.testing.assert_allclose(backend.array([2, 0.5, 3, 1.75, -1, 4.5]), out1, 1e-6) + np.testing.assert_allclose(backend.array([2, 0.5, 3, 1.75, 0, 4.5]), out2, 1e-6) def test_invalid_input(): @@ -1452,7 +1486,9 @@ def test_tensoritem_multiple_slice_3(): ) outputs = pm.evaluate() - assert outputs["output"].shape == (1, 6) + out = outputs["output"] + assert isinstance(out, jnp.ndarray) + assert out.shape == (1, 6) def test_tensor_item_with_ellipsis_at_beginning(): @@ -1465,6 +1501,7 @@ def test_tensor_item_with_ellipsis_at_beginning(): pm = mithril.compile(model, backend=backend) output = pm.evaluate(data)["output"] + assert isinstance(output, jnp.ndarray) assert output.shape == (3, 4) np.testing.assert_allclose(output, data["input"][..., 3]) @@ -1480,6 +1517,7 @@ def test_tensor_item_with_ellipsis_in_middle(): pm = mithril.compile(model, backend=backend) output = pm.evaluate(data)["output"] + assert isinstance(output, jnp.ndarray) assert output.shape == (3, 4, 5) np.testing.assert_allclose(output, data["input"][0, ..., 3]) @@ -1509,8 +1547,10 @@ def test_tranpose_2(): pm = mithril.compile(model, backend=backend) outputs = pm.evaluate({"input": backend.ones(16, 4, 8)}) + out = outputs["output"] + assert isinstance(out, jnp.ndarray) - assert (backend.transpose(backend.ones(16, 4, 8)) == outputs["output"]).all() + assert (backend.transpose(backend.ones(16, 4, 8)) == out).all() def test_tranpose_3(): @@ -1526,8 +1566,10 @@ def test_tranpose_3(): pm = mithril.compile(model, backend=backend) outputs = pm.evaluate({"input": input_arr}) + out = outputs["output"] + assert isinstance(out, jnp.ndarray) - assert (backend.transpose(input_arr, axis) == outputs["output"]).all() + assert (backend.transpose(input_arr, axis) == out).all() def test_tranpose_4(): @@ -1543,8 +1585,10 @@ def test_tranpose_4(): pm = mithril.compile(model, backend=backend) outputs = pm.evaluate({"input": input_arr}) + out = outputs["output"] + assert isinstance(out, jnp.ndarray) - assert (backend.transpose(input_arr, axis) == outputs["output"]).all() + assert (backend.transpose(input_arr, axis) == out).all() def test_split_direct(): @@ -1559,8 +1603,10 @@ def test_split_direct(): pm = mithril.compile(model, backend) outputs = pm.evaluate({"input": input_arr}) + out = outputs["output"] + assert isinstance(out, jnp.ndarray) - assert (jnp.stack(jnp.split(input_arr, 2, axis=1)) == outputs["output"]).all() + assert (jnp.stack(jnp.split(input_arr, 2, axis=1)) == out).all() def test_split_compare_with_explicit(): diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index 59bca064..f2f69620 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -16,6 +16,7 @@ import numpy as np import pytest +import torch import mithril from mithril import TorchBackend @@ -294,8 +295,10 @@ def test_9(): pm = mithril.compile(model=model, backend=backend, jit=False) res = pm.evaluate(params={"input": backend.ones(5, 5)}) + out1 = res["output"] + assert isinstance(out1, torch.Tensor) np.testing.assert_array_equal( - res["output"], backend.array(backend.sigmoid(backend.relu(backend.ones(5, 5)))) + out1, backend.array(backend.sigmoid(backend.relu(backend.ones(5, 5)))) ) @@ -309,10 +312,12 @@ def test_10(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) res = pm.evaluate(params={"input": backend.ones(5, 5)}) + out = res["output"] + assert isinstance(out, torch.Tensor) assert res.keys() == {"output", "middle"} np.testing.assert_array_equal( - res["output"], backend.array(backend.sigmoid(backend.relu(backend.ones(5, 5)))) + out, backend.array(backend.sigmoid(backend.relu(backend.ones(5, 5)))) ) @@ -326,8 +331,10 @@ def test_11(): pm = mithril.compile(model=model, backend=backend, jit=False) res = pm.evaluate(params={"input": backend.ones(5, 5)}) + out = res["output"] + assert isinstance(out, torch.Tensor) np.testing.assert_array_equal( - res["output"], backend.array(backend.sigmoid(backend.relu(backend.ones(5, 5)))) + out, backend.array(backend.sigmoid(backend.relu(backend.ones(5, 5)))) ) @@ -340,10 +347,12 @@ def test_12(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) res = pm.evaluate(params={"input": backend.ones(5, 5)}) + out = res["output"] + assert isinstance(out, torch.Tensor) assert res.keys() == {"output"} np.testing.assert_array_equal( - res["output"], backend.array(backend.sigmoid(backend.relu(backend.ones(5, 5)))) + out, backend.array(backend.sigmoid(backend.relu(backend.ones(5, 5)))) ) @@ -358,14 +367,16 @@ def test_13(): backend = TorchBackend() pm = mithril.compile(model=model, backend=backend, jit=False) res = pm.evaluate(params={"input": backend.ones(5, 5)}) + out1 = res["output1"] + assert isinstance(out1, torch.Tensor) + out2 = res["output2"] + assert isinstance(out2, torch.Tensor) assert res.keys() == {"output1", "output2"} np.testing.assert_array_equal( - res["output1"], backend.array(backend.sigmoid(backend.ones(5, 5))) - ) - np.testing.assert_array_equal( - res["output2"], backend.array(backend.relu(backend.ones(5, 5))) + out1, backend.array(backend.sigmoid(backend.ones(5, 5))) ) + np.testing.assert_array_equal(out2, backend.array(backend.relu(backend.ones(5, 5)))) def test_iokey_shapes_1(): @@ -605,11 +616,12 @@ def test_iokey_values_10(): results = pm.evaluate() expected_result = np.array(backend.sigmoid(backend.array([1.0, 2.0]))) - - np.testing.assert_allclose(results["output"], expected_result, rtol=1e-6, atol=1e-6) - np.testing.assert_allclose( - results["output2"], expected_result, rtol=1e-6, atol=1e-6 - ) + out = results["output"] + out2 = results["output2"] + assert isinstance(out, torch.Tensor) + assert isinstance(out2, torch.Tensor) + np.testing.assert_allclose(out, expected_result, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(out2, expected_result, rtol=1e-6, atol=1e-6) def test_iokey_values_11(): @@ -1269,7 +1281,9 @@ def test_iokey_template_6(): pm._output_keys = {"output"} res = pm.evaluate(params={"input": backend.ones((3, 4, 5))}) - np.testing.assert_almost_equal(res["output"], np.ones((4, 5))) + out = res["output"] + assert isinstance(out, torch.Tensor) + np.testing.assert_almost_equal(out, np.ones((4, 5))) def test_iokey_template_7(): diff --git a/tests/scripts/test_jittable.py b/tests/scripts/test_jittable.py index 800ffe51..57bc5c1b 100644 --- a/tests/scripts/test_jittable.py +++ b/tests/scripts/test_jittable.py @@ -171,7 +171,9 @@ def test_mymodel_jax_1(): ) inputs = compiled_model.randomize_params() result = compiled_model.evaluate(inputs) - output_gradients = {"output": jnp.ones_like(result["output"])} + out = result["output"] + assert isinstance(out, jnp.ndarray) + output_gradients = {"output": jnp.ones_like(out)} outputs, grads = compiled_model.evaluate_all( params=inputs, output_gradients=output_gradients ) @@ -188,8 +190,10 @@ def test_mymodel_jax_2(): ) inputs = compiled_model.randomize_params() result = compiled_model.evaluate(inputs) - output_gradients = {"output": jnp.ones_like(result["output"])} - outputs, grads = compiled_model.evaluate_all( + out = result["output"] + assert isinstance(out, jnp.ndarray) + output_gradients = {"output": jnp.ones_like(out)} + _, grads = compiled_model.evaluate_all( params=inputs, output_gradients=output_gradients ) # assert_results_equal(outputs, ref_output) diff --git a/tests/scripts/test_metrics.py b/tests/scripts/test_metrics.py index f428a615..e3ed7d1f 100644 --- a/tests/scripts/test_metrics.py +++ b/tests/scripts/test_metrics.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import torch from mithril import TorchBackend, compile from mithril.models import ( @@ -111,10 +112,12 @@ def test_metrics_1(): } for key in expected_results: + res = result[key] + assert isinstance(res, torch.Tensor) if key in result: - np.testing.assert_allclose( - result[key], expected_results[key], atol=TOLERANCE - ) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_2(): @@ -193,7 +196,9 @@ def test_metrics_2(): } for key in result: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_3(): @@ -273,7 +278,9 @@ def test_metrics_3(): } for key in result: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_4(): @@ -353,7 +360,9 @@ def test_metrics_4(): } for key in result: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_5(): @@ -433,7 +442,9 @@ def test_metrics_5(): } for key in result: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_6(): @@ -454,6 +465,7 @@ def test_metrics_6(): params={}, data={"pred": backend.array(pred), "label": backend.array(label)}, )["output"] + assert isinstance(result, torch.Tensor) np.testing.assert_allclose(result, expected_result, atol=TOLERANCE) @@ -475,6 +487,7 @@ def test_metrics_7(): params={}, data={"pred": backend.array(pred), "label": backend.array(label)}, )["output"] + assert isinstance(result, torch.Tensor) np.testing.assert_allclose(result, expected_result, atol=TOLERANCE) @@ -498,6 +511,7 @@ def test_metrics_8(): params={}, data={"pred": backend.array(pred), "label": backend.array(label)}, )["output"] + assert isinstance(result, torch.Tensor) np.testing.assert_allclose(result, expected_result, atol=TOLERANCE) @@ -522,6 +536,8 @@ def test_metrics_9(): data={"pred": backend.array(pred), "label": backend.array(label)}, )["output"] + assert isinstance(result, torch.Tensor) + np.testing.assert_allclose(result, expected_result, atol=TOLERANCE) @@ -549,6 +565,8 @@ def test_metrics_10(): data={"pred": backend.array(pred), "label": backend.array(label)}, )["output"] + assert isinstance(result, torch.Tensor) + np.testing.assert_allclose(result, expected_result, atol=TOLERANCE) @@ -576,6 +594,7 @@ def test_metrics_11(): data={"pred": backend.array(pred), "label": backend.array(label)}, )["output"] + assert isinstance(result, torch.Tensor) np.testing.assert_allclose(result, expected_result, atol=TOLERANCE) @@ -647,7 +666,9 @@ def test_metrics_12(): } for key in expected_results: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_13(): @@ -718,7 +739,10 @@ def test_metrics_13(): } for key in expected_results: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_14(): @@ -789,7 +813,9 @@ def test_metrics_14(): } for key in expected_results: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_15(): @@ -860,7 +886,9 @@ def test_metrics_15(): } for key in expected_results: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_16(): @@ -931,7 +959,9 @@ def test_metrics_16(): } for key in expected_results: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) def test_metrics_17(): @@ -1002,4 +1032,6 @@ def test_metrics_17(): } for key in expected_results: - np.testing.assert_allclose(result[key], expected_results[key], atol=TOLERANCE) + out = result[key] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expected_results[key], atol=TOLERANCE) diff --git a/tests/scripts/test_parallel.py b/tests/scripts/test_parallel.py index d4e8dc25..648d556e 100644 --- a/tests/scripts/test_parallel.py +++ b/tests/scripts/test_parallel.py @@ -23,6 +23,7 @@ import numpy as np import pytest import torch +import torch.distributed import mithril from mithril import compile @@ -354,10 +355,11 @@ def test_torch_parallel_1(): input_parallel = {"input": backend_parallel.array(tensor1, device_mesh=(4,))} result_parallel = pm_parallel.evaluate(params_parallel, input_parallel) + out = result_parallel["output"] + assert isinstance(out, torch.distributed._tensor.DTensor) + assert out._local_tensor.shape == (2, 256) - assert result_parallel["output"]._local_tensor.shape == (2, 256) - - output_full_tensor = result_parallel["output"].full_tensor().cpu() + output_full_tensor = out.full_tensor().cpu() np.testing.assert_allclose(output_full_tensor, (torch.ones(8, 256) * 129)) output_grad = backend.randn(8, 256) @@ -372,7 +374,7 @@ def test_torch_parallel_1(): ) for key, grad in grads.items(): - parallel_grad = grads_parallel.get(key).full_tensor() + parallel_grad = grads_parallel[key].full_tensor() np.testing.assert_allclose(grad.cpu(), parallel_grad.cpu(), 1e-6, 1e-6) @@ -392,8 +394,10 @@ def test_torch_parallel_2(): # Replicate params input = {"input": backend.ones(256, 128, device_mesh=(4, 1))} result = pm.evaluate(params, input) + out = result["output"] + assert isinstance(out, torch.distributed._tensor.DTensor) - output_full_tensor = result["output"].full_tensor() + output_full_tensor = out.full_tensor() np.testing.assert_allclose( output_full_tensor, (torch.ones(256, 256) * 129 + torch.eye(256)) ) @@ -432,7 +436,9 @@ def test_torch_parallel_3(): } result_parallel = pm_parallel.evaluate(params_parallel, input_parallel) - output_full_tensor = result_parallel["output"].full_tensor() + out = result_parallel["output"] + assert isinstance(out, torch.distributed._tensor.DTensor) + output_full_tensor = out.full_tensor() np.testing.assert_allclose( output_full_tensor, (torch.ones(256, 256) * 129 + (torch.arange(4).repeat(64) + 1)), @@ -450,7 +456,7 @@ def test_torch_parallel_3(): ) for key, grad in grads.items(): - parallel_grad = grads_parallel.get(key).full_tensor() + parallel_grad = grads_parallel[key].full_tensor() np.testing.assert_allclose(grad, parallel_grad, 1e-6, 1e-6) @@ -476,7 +482,9 @@ def test_torch_parallel_4(): input_parallel = {"input": backend_parallel.ones(256, 128, device_mesh=(4, 1))} result_parallel = pm_parallel.evaluate(params_parallel, input_parallel) - output_full_tensor = result_parallel["output"].full_tensor() + out = result_parallel["output"] + assert isinstance(out, torch.distributed._tensor.DTensor) + output_full_tensor = out.full_tensor() np.testing.assert_allclose(output_full_tensor, torch.ones(256, 256) * 129 + 3) output_grad = backend.rand(256, 256) @@ -489,9 +497,8 @@ def test_torch_parallel_4(): input_parallel, output_gradients={"output": output_grad_parallel}, ) - for key, grad in grads.items(): - parallel_grad = grads_parallel.get(key).full_tensor() + parallel_grad = grads_parallel[key].full_tensor() np.testing.assert_allclose(grad, parallel_grad, 1e-6, 1e-6) @@ -533,8 +540,13 @@ def test_torch_parallel_5(): output_grads = backend.rand(256, 256) outout_grads_parallel = backend_parallel.array(output_grads) - output_full_tensor = result_parallel["output"].full_tensor() - np.testing.assert_allclose(output_full_tensor, result["output"]) + out_parallel = result_parallel["output"] + out = result["output"] + assert isinstance(out_parallel, torch.distributed._tensor.DTensor) + assert isinstance(out, torch.Tensor) + + output_full_tensor = out_parallel.full_tensor() + np.testing.assert_allclose(output_full_tensor, out) param_grads = pm.evaluate_gradients( params, input, output_gradients={"output": output_grads} @@ -545,7 +557,7 @@ def test_torch_parallel_5(): output_gradients={"output": outout_grads_parallel}, ) for key, grad in param_grads.items(): - parallel_grad = param_grads_parallel.get(key).full_tensor() + parallel_grad = param_grads_parallel[key].full_tensor() np.testing.assert_allclose(grad, parallel_grad, 1e-6, 1e-6) @@ -565,8 +577,11 @@ def test_torch_static_parallel_1(): params = {"b": backend.ones([256])} result = pm.evaluate(params) + out = result["output"] + assert isinstance(out, torch.distributed._tensor.DTensor) + + output_full_tensor = out.full_tensor() - output_full_tensor = result["output"].full_tensor() np.testing.assert_allclose( output_full_tensor, ((torch.ones(256, 256) * 129).sigmoid()) ) @@ -587,8 +602,10 @@ def test_torch_static_parallel_2(): pm = compile(model, backend, jit=False, constant_keys=static_inputs) result = pm.evaluate() + out = result["output"] + assert isinstance(out, torch.distributed._tensor.DTensor) - output_full_tensor = result["output"].full_tensor() + output_full_tensor = out.full_tensor() np.testing.assert_allclose( output_full_tensor, ((torch.ones(256, 256) * 129).sigmoid()) ) @@ -611,9 +628,13 @@ def test_torch_static_parallel_3(): pm = compile(model, backend, jit=False, constant_keys=static_inputs) result = pm.evaluate() + out1 = result["output"] + out2 = result["output2"] + assert isinstance(out1, torch.distributed._tensor.DTensor) + assert isinstance(out2, torch.distributed._tensor.DTensor) - output_full_tensor = result["output"].full_tensor() - output2_full_tensor = result["output2"].full_tensor() + output_full_tensor = out1.full_tensor() + output2_full_tensor = out2.full_tensor() np.testing.assert_allclose( output_full_tensor, ((torch.ones(256, 256) * 129).relu()) ) @@ -805,12 +826,17 @@ def test_torch_parallel_multi_parallel_1(): res1 = pm1.evaluate({}, {"left": left, "right": right}) res2 = pm2.evaluate({}, {"left": left, "right": right}) + out1 = res1["output"] + out2 = res2["output"] + + assert isinstance(out1, torch.distributed._tensor.DTensor) + assert isinstance(out2, torch.distributed._tensor.DTensor) np.testing.assert_allclose( - res1["output"].full_tensor(), + out1.full_tensor(), (left + right).full_tensor(), # type: ignore ) np.testing.assert_allclose( - res2["output"].full_tensor(), + out2.full_tensor(), (left * right).full_tensor(), # type: ignore ) @@ -902,10 +928,11 @@ def test_jax_parallel_1(): input_parallel = {"input": backend_parallel.array(tensor1, device_mesh=(4,))} result_parallel = pm_parallel.evaluate(params_parallel, input_parallel) - - assert result_parallel["output"].sharding.shape == (4, 1) + out = result_parallel["output"] + assert isinstance(out, jax.Array) output_full_tensor = result_parallel["output"] + assert isinstance(output_full_tensor, jax.numpy.ndarray) np.testing.assert_allclose(output_full_tensor, (jax.numpy.ones((8, 256)) * 129)) output_grad = backend.randn(8, 256) @@ -919,8 +946,8 @@ def test_jax_parallel_1(): output_gradients={"output": output_grad_parallel}, ) - for key, _grad in grads: - parallel_grad = grads_parallel.get(key) + for _key, _grad in grads.items(): + parallel_grad = grads_parallel[_key] np.testing.assert_allclose(_grad, parallel_grad, 1e-5, 1e-5) @@ -943,6 +970,7 @@ def test_jax_parallel_2(): result = pm.evaluate(params, input) output_full_tensor = result["output"] + assert isinstance(output_full_tensor, jax.numpy.ndarray) np.testing.assert_allclose( output_full_tensor, (jax.numpy.ones((256, 256)) * 129 + jax.numpy.eye(256)) ) @@ -984,6 +1012,7 @@ def test_jax_parallel_3(): result_parallel = pm_parallel.evaluate(params_parallel, input_parallel) output_full_tensor = result_parallel["output"] + assert isinstance(output_full_tensor, jax.numpy.ndarray) np.testing.assert_allclose( output_full_tensor, ( @@ -1004,7 +1033,7 @@ def test_jax_parallel_3(): ) for key, grad in grads.items(): - parallel_grad = grads_parallel.get(key) + parallel_grad = grads_parallel[key] np.testing.assert_allclose(grad, parallel_grad, 1e-5, 1e-5) @@ -1032,6 +1061,7 @@ def test_jax_parallel_4(): result_parallel = pm_parallel.evaluate(params_parallel, input_parallel) output_full_tensor = result_parallel["output"] + assert isinstance(output_full_tensor, jax.numpy.ndarray) np.testing.assert_allclose( output_full_tensor, jax.numpy.ones((256, 256)) * 129 + 3 ) @@ -1048,7 +1078,7 @@ def test_jax_parallel_4(): ) for key, grad in grads.items(): - parallel_grad = grads_parallel.get(key) + parallel_grad = grads_parallel[key] np.testing.assert_allclose(grad, parallel_grad, 1e-5, 1e-5) @@ -1092,7 +1122,10 @@ def test_jax_parallel_5(): output_grads = backend.randn(256, 256) outout_grads_parallel = backend_parallel.array(output_grads) output_full_tensor = result_parallel["output"] - np.testing.assert_allclose(output_full_tensor, result["output"]) + out = result["output"] + assert isinstance(out, jax.numpy.ndarray) + assert isinstance(output_full_tensor, jax.numpy.ndarray) + np.testing.assert_allclose(output_full_tensor, out) param_grads = pm.evaluate_gradients( params, input, output_gradients={"output": output_grads} @@ -1103,5 +1136,5 @@ def test_jax_parallel_5(): output_gradients={"output": outout_grads_parallel}, ) for key, grad in param_grads.items(): - parallel_grad = param_grads_parallel.get(key) + parallel_grad = param_grads_parallel[key] np.testing.assert_allclose(grad, parallel_grad, 1e-5, 1e-5) diff --git a/tests/scripts/test_recurrent_models.py b/tests/scripts/test_recurrent_models.py index e9a484cf..c5075dd9 100644 --- a/tests/scripts/test_recurrent_models.py +++ b/tests/scripts/test_recurrent_models.py @@ -784,6 +784,7 @@ def test_torch_encoder_decoder_var_seq_len(): ) for idx in range(max_seq_len): mithril_output = outputs[f"output{idx}"] + assert isinstance(mithril_output, torch.Tensor) torch.testing.assert_close( mithril_output, output[: mithril_output.shape[0], idx : idx + 1, :], @@ -796,7 +797,7 @@ def test_torch_encoder_decoder_var_seq_len(): loss.backward() for key, value in torch_model.named_parameters(): - grad = gradients_mithril[key_map.get(key)] + grad = gradients_mithril[key_map[key]] grad = grad.permute(*torch.arange(grad.ndim - 1, -1, -1)) torch.testing.assert_close(grad, value.grad) diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index a2a3108e..e0c929b3 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -1386,6 +1386,7 @@ def test_check_static_1(): # "b": np.array([3.0])} outputs = comp_model.evaluate() ref_out = outputs["output"] + assert isinstance(ref_out, np.ndarray) np.testing.assert_array_equal(ref_out, np.array([[26.0], [27.0]])) @@ -1398,6 +1399,7 @@ def test_check_static_2(): inputs = {"w": np.array([[4.0], [5.0]]), "b": np.array([3.0])} outputs = comp_model.evaluate(inputs) ref_out = outputs["output"] + assert isinstance(ref_out, np.ndarray) np.testing.assert_array_equal(ref_out, np.array([[26.0], [27.0]])) @@ -1410,6 +1412,7 @@ def test_check_static_3(): inputs = {"b": np.array([3.0])} outputs = comp_model.evaluate(inputs) ref_out = outputs["output"] + assert isinstance(ref_out, np.ndarray) np.testing.assert_array_equal(ref_out, np.array([[26.0], [27.0]])) @@ -1429,6 +1432,7 @@ def test_check_static_4(): ) outputs = comp_model.evaluate() ref_out = outputs["output"] + assert isinstance(ref_out, np.ndarray) np.testing.assert_array_equal(ref_out, np.array([[26.0], [27.0]])) @@ -1451,6 +1455,7 @@ def test_check_static_5(): outputs = comp_model.evaluate(data=data) ref_out = outputs["output"] + assert isinstance(ref_out, np.ndarray) np.testing.assert_array_equal(ref_out, np.array([[26.0], [27.0]])) @@ -1475,6 +1480,7 @@ def test_check_static_6(): outputs = comp_model.evaluate(data=data) ref_out = outputs["output"] + assert isinstance(ref_out, np.ndarray) np.testing.assert_array_equal(ref_out, np.array([[26.0], [27.0]])) @@ -1582,8 +1588,8 @@ def test_batch_minibatch_grad(): batch_grad_results = pm.evaluate_gradients( inputs, data={"input": backend_input, "target": backend_target} ) - minibatch_result = [] - minibatch_grad_result = [] + minibatch_result: list[dict] = [] + minibatch_grad_result: list[dict] = [] # Split into minibatches for idx in range(8): @@ -1601,7 +1607,8 @@ def test_batch_minibatch_grad(): "target": backend_target[idx : idx + 1], }, ) - minibatch_result.append(result) + assert isinstance(result["final_cost"], torch.Tensor) + minibatch_result.append(result) # type: ignore minibatch_grad_result.append(grad_result) minibatch_cost = sum([minibatch_result[i]["final_cost"] for i in range(8)]) / 8 @@ -1610,6 +1617,7 @@ def test_batch_minibatch_grad(): for key in minibatch_grad_result[0] } batch_cost = batch_result["final_cost"] + assert isinstance(batch_cost, torch.Tensor) assert np.isclose(minibatch_cost, batch_cost, rtol=1e-6, atol=1e-6) assert list(batch_grad_results.keys()) == list(minibatch_grads.keys()) for key in batch_grad_results: @@ -1823,8 +1831,8 @@ def test_arange_primitive(): params = {"b1": _backend.ones(1), "w1": _backend.ones((3, 1))} data = {"input": _backend.ones((1, 3))} output = pm.evaluate(params, data) - assert (output["arange_res"] == _backend.arange(arange_len)).all() - assert output["arange_res"].dtype == _backend.arange(arange_len).dtype + assert (output["arange_res"] == _backend.arange(arange_len)).all() # type: ignore + assert output["arange_res"].dtype == _backend.arange(arange_len).dtype # type: ignore def test_to_tensor_primitive(): @@ -1867,8 +1875,8 @@ def test_to_tensor_primitive(): params = {"b1": _backend.ones(1), "w1": _backend.ones((3, 1))} data = {"input": _backend.ones((1, 3))} output = pm.evaluate(params, data) - assert (output["power_out"] == _backend.array([9])).all() - assert output["power_out"].dtype == _backend.array([9]).dtype + assert (output["power_out"] == _backend.array([9])).all() # type: ignore + assert output["power_out"].dtype == _backend.array([9]).dtype # type: ignore def test_shapes_1(): @@ -1998,10 +2006,10 @@ def test_static_concat(): pm = mithril.compile( model=model, backend=backend, constant_keys={"input": backend.zeros(1)} ) + out = pm.evaluate()["output"] + assert isinstance(out, np.ndarray) - assert all( - pm.evaluate()["output"] == backend.array([0.0, 0.0], dtype=mithril.float32) - ) + assert all(out == backend.array([0.0, 0.0], dtype=mithril.float32)) def test_reduce_overlap_shapes(): @@ -3163,17 +3171,21 @@ def test_arange_1(): expected_result = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) m += Arange(0, 10, 1)(output="output") - backends: list[type[Backend]] = [TorchBackend, JaxBackend, NumpyBackend, MlxBackend] + backends: list[ + type[JaxBackend] | type[TorchBackend] | type[NumpyBackend] | type[MlxBackend] + ] = [TorchBackend, JaxBackend, NumpyBackend, MlxBackend] for backend_class in backends: if backend_class.is_installed: backend = backend_class(precision=32) cm = compile( - m, backend, inference=True + m, + backend, + inference=True, # type: ignore ) # Inference set to True since no gradients exist for integer type output # of Arange! - np.testing.assert_allclose( - expected_result, cm.evaluate({})["output"], rtol=1e-6, atol=1e-6 - ) + out = cm.evaluate({})["output"] + assert isinstance(out, backend.DataType) + np.testing.assert_allclose(expected_result, out, rtol=1e-6, atol=1e-6) def test_arange_2(): @@ -3187,7 +3199,10 @@ def test_arange_2(): backend = backend_class(precision=32) cm = compile(m, backend) np.testing.assert_allclose( - expected_result, cm.evaluate({})["output"], rtol=1e-6, atol=1e-6 + expected_result, + cm.evaluate({})["output"], # type: ignore + rtol=1e-6, + atol=1e-6, ) @@ -3196,14 +3211,16 @@ def test_arange_3(): expected_result = np.array([0.1, 0.7, 1.3, 1.9, 2.5, 3.1, 3.7]) m += Arange(0.1, 4, 0.6)(output="output") - backends: list[type[Backend]] = [TorchBackend, JaxBackend, NumpyBackend, MlxBackend] + backends: list[ + type[TorchBackend] | type[JaxBackend] | type[NumpyBackend] | type[MlxBackend] + ] = [TorchBackend, JaxBackend, NumpyBackend, MlxBackend] for backend_class in backends: if backend_class.is_installed: backend = backend_class(precision=32) - cm = compile(m, backend) - np.testing.assert_allclose( - expected_result, cm.evaluate({})["output"], rtol=1e-6, atol=1e-6 - ) + cm = compile(m, backend) # type: ignore + out = cm.evaluate({})["output"] + assert isinstance(out, backend.DataType) + np.testing.assert_allclose(expected_result, out, rtol=1e-6, atol=1e-6) def test_size(): @@ -3235,7 +3252,7 @@ def test_size(): cm = compile(model, backend, data_keys={"input"}, inference=True) np.testing.assert_allclose( expected_result, - cm.evaluate(data={"input": backend.array(input_array)})["output"], + cm.evaluate(data={"input": backend.array(input_array)})["output"], # type: ignore rtol=1e-6, atol=1e-6, ) @@ -3473,6 +3490,8 @@ def test_evaluate_all_2(): assert eval_out.keys() == eval_all_out[0].keys() for val1, val2 in zip(eval_out.values(), eval_all_out[0].values(), strict=False): + assert isinstance(val1, backend.DataType) + assert isinstance(val2, backend.DataType) np.testing.assert_allclose(val1, val2, rtol=1e-7, atol=1e-7) assert eval_grad_out.keys() == eval_all_out[1].keys() @@ -4148,8 +4167,12 @@ def test_mlp_last_dimension_prop_2(): comp_model = mithril.compile(model=ctx, backend=NumpyBackend()) inputs = {"in1": np.array([3.0]), "in2": np.array([2.0])} outputs = comp_model.evaluate(inputs) - np.testing.assert_allclose(outputs["final_cost"], np.array(3.0)) - np.testing.assert_allclose(outputs["output"], np.array(5.0)) + output_final_cost = outputs["final_cost"] + out = outputs["output"] + assert isinstance(output_final_cost, np.ndarray) + assert isinstance(out, np.ndarray) + np.testing.assert_allclose(output_final_cost, np.array(3.0)) + np.testing.assert_allclose(out, np.array(5.0)) def test_connect_8(): @@ -4579,7 +4602,9 @@ def test_cycle_handling_1(): } res = compiled_model.evaluate(inputs) - np.testing.assert_allclose(res["output"], expceted_result, rtol=1e-14, atol=1e-14) + out = res["output"] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expceted_result, rtol=1e-14, atol=1e-14) assert_connections(compiled_model, expected_connections) @@ -4703,7 +4728,9 @@ def test_cycle_handling_2(): ) res = compiled_model.evaluate(inputs) - np.testing.assert_allclose(res["output"], expceted_result, rtol=1e-14, atol=1e-14) + out = res["output"] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expceted_result, rtol=1e-14, atol=1e-14) assert_connections(compiled_model, expected_connections) @@ -4846,7 +4873,9 @@ def test_cycle_handling_3(): ) res = compiled_model.evaluate(inputs) - np.testing.assert_allclose(res["output"], expceted_result, rtol=1e-14, atol=1e-14) + out = res["output"] + assert isinstance(out, torch.Tensor) + np.testing.assert_allclose(out, expceted_result, rtol=1e-14, atol=1e-14) assert_connections(compiled_model, expected_connections) @@ -6160,20 +6189,24 @@ def test_to_tensor(): # Test for torch pm_torch = compile(model, TorchBackend(precision=64)) result_torch = pm_torch.evaluate({}, {"input": input1})["output"] + assert isinstance(result_torch, torch.Tensor) expected_torch = torch.tensor(input1, dtype=torch.float64) np.testing.assert_allclose(result_torch, expected_torch, 1e-12) result_torch = pm_torch.evaluate({}, {"input": input2})["output"] + assert isinstance(result_torch, torch.Tensor) expected_torch = torch.tensor(input2, dtype=torch.bool) assert (result_torch == expected_torch).all() # Test for Jax pm_jax = compile(model, JaxBackend(precision=64), jit=False) result = pm_jax.evaluate({}, {"input": input1})["output"] + assert isinstance(result, jax.numpy.ndarray) expected = jax.numpy.array(input1, jax.numpy.float64) np.testing.assert_allclose(result, expected, 1e-12) result = pm_jax.evaluate({}, {"input": input2})["output"] + assert isinstance(result, jax.numpy.ndarray) expected = jax.numpy.array(input2, dtype=jax.numpy.bool_) assert (result == expected).all() @@ -6181,20 +6214,24 @@ def test_to_tensor(): if platform.system() == "Darwin": pm_mlx = compile(model, MlxBackend(precision=32)) result_mlx = pm_mlx.evaluate({}, {"input": input1})["output"] + assert isinstance(result_mlx, mx.array) expected_mlx = mx.array(input1, mx.float32) np.testing.assert_allclose(result_mlx, expected_mlx, 1e-6) # type: ignore - result = pm_mlx.evaluate({}, {"input": input2})["output"] + result_mlx = pm_mlx.evaluate({}, {"input": input2})["output"] + assert isinstance(result_mlx, mx.array) expected = mx.array(input2, dtype=mx.bool_) # type: ignore - assert (result == expected).all() + assert (result_mlx == expected).all() # type: ignore # Test for Numpy pm_numpy = compile(model, NumpyBackend(precision=64), jit=False) result_numpy = pm_numpy.evaluate({}, {"input": input1})["output"] + assert isinstance(result_numpy, np.ndarray) expected_numpy = np.array(input1, np.float64) np.testing.assert_allclose(result_numpy, expected_numpy, 1e-12) result_numpy = pm_numpy.evaluate({}, {"input": input2})["output"] + assert isinstance(result_numpy, np.ndarray) expected_numpy = np.array(input2, dtype=np.bool_) assert (result_numpy == expected_numpy).all() @@ -6444,6 +6481,7 @@ def test_numpy_type_promotion_1(): ) for output in outputs.values(): + assert isinstance(output, np.ndarray) assert output.dtype == np.float16 @@ -6478,6 +6516,7 @@ def test_numpy_type_promotion_2(): ) for output in outputs.values(): + assert isinstance(output, np.ndarray) assert output.dtype == np.float32 @@ -6506,6 +6545,7 @@ def test_numpy_type_promotion_3(): outputs = pm.evaluate() for output in outputs.values(): + assert isinstance(output, np.ndarray) assert output.dtype == np.float16 @@ -6530,8 +6570,9 @@ def test_numpy_type_promotion_4(): pm = compile( model, backend=backend, jit=False, constant_keys={"left": left, "right": right} ) + from typing import Any - outputs = pm.evaluate() + outputs: dict[str, np.ndarray[Any, Any]] = pm.evaluate() # type: ignore for output in outputs.values(): assert output.dtype == np.float32 @@ -6568,6 +6609,7 @@ def test_numpy_type_promotion_5(): outputs = pm.evaluate({}, {"left": np.ones((3, 3), dtype=np.int16)}) for output in outputs.values(): + assert isinstance(output, np.ndarray) assert output.dtype == np.float16 @@ -6717,7 +6759,9 @@ def test_constant_1(): expected = np.array( [epsilon_table[precision][Constant.EPSILON]] * 2, dtype=np.float64 ) - np.testing.assert_almost_equal(pm.evaluate()["out"], expected, 20) + out = pm.evaluate()["out"] + assert isinstance(out, np.ndarray) + np.testing.assert_almost_equal(out, expected, 20) def test_constant_2(): @@ -6732,7 +6776,9 @@ def test_constant_2(): expected = np.array( [epsilon_table[precision][Constant.EPSILON]] * 2, dtype=np.float64 ) - np.testing.assert_almost_equal(pm.evaluate()["out"], expected, 20) + out = pm.evaluate()["out"] + assert isinstance(out, np.ndarray) + np.testing.assert_almost_equal(out, expected, 20) def test_constant_3(): @@ -6745,7 +6791,9 @@ def test_constant_3(): expected = np.array( [epsilon_table[precision][Constant.EPSILON]] * 2, dtype=np.float32 ) - np.testing.assert_almost_equal(pm.evaluate()["out"], expected, 20) + out = pm.evaluate()["out"] + assert isinstance(out, np.ndarray) + np.testing.assert_almost_equal(out, expected, 20) def test_constant_4(): @@ -6760,7 +6808,9 @@ def test_constant_4(): expected = np.array( [epsilon_table[precision][Constant.EPSILON]] * 2, dtype=np.float32 ) - np.testing.assert_almost_equal(pm.evaluate()["out"], expected, 20) + out = pm.evaluate()["out"] + assert isinstance(out, np.ndarray) + np.testing.assert_almost_equal(out, expected, 20) def test_constant_5(): diff --git a/tests/scripts/test_train_context.py b/tests/scripts/test_train_context.py index 0167fced..1f5106d1 100644 --- a/tests/scripts/test_train_context.py +++ b/tests/scripts/test_train_context.py @@ -335,10 +335,12 @@ def test_add_metric_1(): input = backend.randn(5, 5) c_model = mithril.compile(ctx2, backend, data_keys={"input"}) result = c_model.evaluate({}, {"input": input}) + res_metric = result["metric"] + assert isinstance(res_metric, np.ndarray) assert "metric" in ctx2.output_keys - assert result["metric"].shape == input.shape - np.testing.assert_almost_equal(result["metric"], np.where(input > 0, input, 0)) + assert res_metric.shape == input.shape + np.testing.assert_almost_equal(res_metric, np.where(input > 0, input, 0)) def test_add_metric_2(): @@ -359,8 +361,10 @@ def test_add_metric_2(): expected_metric = np.array(np.mean(np.where(input > 0, input, 0))) assert "metric" in ctx2.output_keys - assert result["metric"].shape == expected_metric.shape - np.testing.assert_almost_equal(result["metric"], expected_metric) + res_metric = result["metric"] + assert isinstance(res_metric, np.ndarray) + assert res_metric.shape == expected_metric.shape + np.testing.assert_almost_equal(res_metric, expected_metric) def test_add_regularization_case_1(): @@ -483,13 +487,13 @@ def test_autogenerated_key_regularization_integrated_linear_9(): result = comp_train_model.evaluate(params, data) gradients = comp_train_model.evaluate_gradients(params, data=data) + res_out = result["output"] + res_cost = result["final_cost"] + assert isinstance(res_out, np.ndarray) + assert isinstance(res_cost, np.ndarray) - assert ( - backend.abs(result["output"] - backend.array([[0.5], [0.8], [1.0]])) <= 1e-14 - ).all() - assert ( - result["final_cost"] - 1.996666666666666666666666666666666666666667 - ) <= 1e-14 + assert (backend.abs(res_out - backend.array([[0.5], [0.8], [1.0]])) <= 1e-14).all() + assert (res_cost - 1.996666666666666666666666666666666666666667) <= 1e-14 assert ( backend.abs( gradients["w"] @@ -542,11 +546,15 @@ def test_autogenerated_key_regularization_integrated_nn_7_regex(): ) result = comp_train_model.evaluate(params, data) + res_final_cost = result["final_cost"] + res_out = result["output"] + assert isinstance(res_final_cost, np.ndarray) + assert isinstance(res_out, np.ndarray) gradients = comp_train_model.evaluate_gradients(params, data=data) assert ( backend.abs( - result["output"] + res_out - backend.array( [ [ @@ -558,7 +566,7 @@ def test_autogenerated_key_regularization_integrated_nn_7_regex(): ) <= 1e-14 ).all() - assert (result["final_cost"] - 11.883655622163706) <= 1e-14 + assert (res_final_cost - 11.883655622163706) <= 1e-14 assert ( backend.abs( gradients["w0"] diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index f8dcee32..c46893d2 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -15,6 +15,7 @@ from types import UnionType from typing import Any +import jax.numpy as jnp import pytest import mithril @@ -366,11 +367,14 @@ def test_tuple_conversion_2(): params = pm_1.randomize_params() eval_1 = pm_1.evaluate(params=params) eval_2 = pm_2.evaluate(params=params) - output_gradients = {"output": backend.randn(eval_1["output"].shape)} + eval_1_output = eval_1["output"] + assert isinstance(eval_1_output, jnp.ndarray) + output_gradients = {"output": backend.randn(eval_1_output.shape)} grad_1 = pm_1.evaluate_gradients(params=params, output_gradients=output_gradients) grad_2 = pm_2.evaluate_gradients(params=params, output_gradients=output_gradients) # Check outputs for key, value in eval_1.items(): + assert isinstance(value, jnp.ndarray) assert (value == eval_2[key]).all() # Check gradients. for key, value in grad_1.items(): @@ -409,11 +413,14 @@ def test_tuple_conversion_3(): params = pm_1.randomize_params() eval_1 = pm_1.evaluate(params=params) eval_2 = pm_2.evaluate(params=params) - output_gradients = {"output": backend.randn(eval_1["output"].shape)} + out = eval_1["output"] + assert isinstance(out, jnp.ndarray) + output_gradients = {"output": backend.randn(out.shape)} grad_1 = pm_1.evaluate_gradients(params=params, output_gradients=output_gradients) grad_2 = pm_2.evaluate_gradients(params=params, output_gradients=output_gradients) # Check outputs for key, value in eval_1.items(): + assert isinstance(value, jnp.ndarray) assert (value == eval_2[key]).all() # Check gradients. for key, value in grad_1.items(): @@ -452,11 +459,14 @@ def test_list_conversion_1(): params = pm_1.randomize_params() eval_1 = pm_1.evaluate(params=params) eval_2 = pm_2.evaluate(params=params) - output_gradients = {"output": backend.randn(eval_1["output"].shape)} + out = eval_1["output"] + assert isinstance(out, jnp.ndarray) + output_gradients = {"output": backend.randn(out.shape)} grad_1 = pm_1.evaluate_gradients(params=params, output_gradients=output_gradients) grad_2 = pm_2.evaluate_gradients(params=params, output_gradients=output_gradients) # Check outputs for key, value in eval_1.items(): + assert isinstance(value, jnp.ndarray) assert (value == eval_2[key]).all() # Check gradients. for key, value in grad_1.items(): @@ -494,11 +504,14 @@ def test_nested_list_conversion_1(): params = pm_1.randomize_params() eval_1 = pm_1.evaluate(params=params) eval_2 = pm_2.evaluate(params=params) - output_gradients = {"output": backend.randn(eval_1["output"].shape)} + out = eval_1["output"] + assert isinstance(out, jnp.ndarray) + output_gradients = {"output": backend.randn(out.shape)} grad_1 = pm_1.evaluate_gradients(params=params, output_gradients=output_gradients) grad_2 = pm_2.evaluate_gradients(params=params, output_gradients=output_gradients) # Check outputs for key, value in eval_1.items(): + assert isinstance(value, jnp.ndarray) assert (value == eval_2[key]).all() # Check gradients. for key, value in grad_1.items(): @@ -537,15 +550,23 @@ def test_nested_list_conversion_2(): params = pm_1.randomize_params() eval_1 = pm_1.evaluate(params=params) eval_2 = pm_2.evaluate(params=params) - output_gradients = {"output": backend.randn(eval_1["output"].shape)} + out = eval_1["output"] + assert isinstance(out, jnp.ndarray) + output_gradients = {"output": backend.randn(out.shape)} grad_1 = pm_1.evaluate_gradients(params=params, output_gradients=output_gradients) grad_2 = pm_2.evaluate_gradients(params=params, output_gradients=output_gradients) # Check outputs for key, value in eval_1.items(): - assert (value == eval_2[key]).all() + assert isinstance(value, jnp.ndarray) + value2 = eval_2[key] + assert isinstance(value2, jnp.ndarray) + assert (value == value2).all() # Check gradients. for key, value in grad_1.items(): - assert (value == grad_2[key]).all() + assert isinstance(value, jnp.ndarray) + value2 = grad_2[key] + assert isinstance(value2, jnp.ndarray) + assert (value == value2).all() def test_type_propagation_1(): diff --git a/tests/scripts/test_type_consistencies.py b/tests/scripts/test_type_consistencies.py index 9f411d57..a8d9bd1a 100644 --- a/tests/scripts/test_type_consistencies.py +++ b/tests/scripts/test_type_consistencies.py @@ -16,6 +16,7 @@ import numpy as np import pytest +import torch import mithril from mithril.framework.common import NOT_GIVEN, ConnectionType @@ -365,11 +366,13 @@ def test_type_15(): results = pm.evaluate() expected_result = np.array(backend.sigmoid(backend.array([1.0, 2.0]))) + out = results["output"] + assert isinstance(out, torch.Tensor) + out2 = results["output2"] + assert isinstance(out2, torch.Tensor) - np.testing.assert_allclose(results["output"], expected_result, rtol=1e-6, atol=1e-6) - np.testing.assert_allclose( - results["output2"], expected_result, rtol=1e-6, atol=1e-6 - ) + np.testing.assert_allclose(out, expected_result, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(out2, expected_result, rtol=1e-6, atol=1e-6) def test_type_16(): diff --git a/tests/utils.py b/tests/utils.py index 6945b98d..77409b1f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -62,7 +62,7 @@ def check_evaluations( outs_2 = pm_2.evaluate(params_2) assert outs_1.keys() == outs_2.keys(), "Keys are not same!" for key, out in outs_1.items(): - assert (outs_2[key] == out).all(), f"Output value for '{key}' key is not equal!" + assert (outs_2[key] == out).all(), f"Output value for '{key}' key is not equal!" # type: ignore # Check gradients. if not inference: From db37cf8222698b38c289bf8b372ae310a951b423 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Wed, 4 Dec 2024 17:58:33 +0300 Subject: [PATCH 10/26] reviews are applied except ones that includes codegen --- mithril/framework/codegen/utils.py | 7 ++++--- mithril/framework/utils.py | 2 +- tests/json_files/randomized_model_tests_all_backends.json | 6 +----- tests/scripts/helper.py | 5 ----- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/mithril/framework/codegen/utils.py b/mithril/framework/codegen/utils.py index 1975aeeb..1752c772 100644 --- a/mithril/framework/codegen/utils.py +++ b/mithril/framework/codegen/utils.py @@ -14,16 +14,17 @@ import ast import keyword -from typing import Any -from ...backends.backend import Backend +from ...backends.backend import Backend, DataType from ..common import ShapeNode key_map_type = dict[str, str] # TODO: This name misleads -def partial_array_creation_func(backend: Backend[Any], formula_key: str) -> ast.stmt: +def partial_array_creation_func( + backend: Backend[DataType], formula_key: str +) -> ast.stmt: kwargs = [ast.keyword(arg="precision", value=ast.Constant(value=backend.precision))] # We don't need device in manulgrad(Numpy) diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 76470948..7a9a8043 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -393,7 +393,7 @@ def find_intersection_type( return None -def find_type(connection: Any) -> type[Any]: +def find_type[T](connection: T) -> type[T]: if isinstance(connection, tuple | list): element_types: list[Any] = [find_type(elem) for elem in connection] if isinstance(connection, tuple): diff --git a/tests/json_files/randomized_model_tests_all_backends.json b/tests/json_files/randomized_model_tests_all_backends.json index d31db327..34ed6750 100644 --- a/tests/json_files/randomized_model_tests_all_backends.json +++ b/tests/json_files/randomized_model_tests_all_backends.json @@ -1551,7 +1551,7 @@ "name": "OneToMany", "differentiability_info": {"input": true, "initial_hidden": true}, "regular_args": { - "max_sequence_length": 6, + "max_sequence_length": 5, "cell_type": "RNNCell" } }, @@ -1573,10 +1573,6 @@ }, "target4": { "shapes": [[4, 4], [1,1], [15 , 15]] - }, - - "target5": { - "shapes": [[2, 2], [1,1], [15 , 15]] } }, "iterations": 20 diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 24ac23b8..29d789ab 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -265,8 +265,3 @@ def assert_evaluations_equal(model1, model2, backend, static_keys): assert list(output_base.keys()) == list(output_recreated.keys()) for key in output_base: assert backend.abs(output_base[key] - output_recreated[key]).all() < 1e-14 # type: ignore - - -class TensorMock: - def __init__(self, value) -> None: - self.value = value From 3be2f7d32adf590b94118afe7e394547a4eac089 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 5 Dec 2024 10:56:20 +0300 Subject: [PATCH 11/26] chore: type annotations in codegen are fixed --- examples/gpt/run_sample.py | 2 +- .../with_manualgrad/numpy_backend/backend.py | 4 +- mithril/framework/codegen/c_gen.py | 2 +- mithril/framework/codegen/code_gen.py | 10 +- mithril/framework/codegen/numpy_gen.py | 49 +++-- mithril/framework/codegen/python_gen.py | 198 +++++++++++++----- mithril/framework/codegen/torch_gen.py | 11 +- mithril/framework/common.py | 35 +++- mithril/framework/constraints.py | 14 +- mithril/framework/physical/model.py | 80 +++---- 10 files changed, 265 insertions(+), 140 deletions(-) diff --git a/examples/gpt/run_sample.py b/examples/gpt/run_sample.py index 47397105..231611f0 100644 --- a/examples/gpt/run_sample.py +++ b/examples/gpt/run_sample.py @@ -155,7 +155,7 @@ def generate( logits = logits[:, -1, :] / temperature # type: ignore # Optionally crop the logits to only the top k options if top_k is not None: - v = model.backend.topk(logits, min(top_k, logits.shape[-1])) + v = model.backend.topk(logits, min(top_k, logits.shape[-1])) # type: ignore logits = model.backend.where( logits < v[:, [-1]], -model.backend.inf, logits ) diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index e3d705a3..3be576e2 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -24,7 +24,7 @@ from . import ops, ops_grad, utils -class NumpyBackend(Backend[np.ndarray]): +class NumpyBackend(Backend[np.ndarray[Any, Any]]): """A backend implementation for the Mithril library using NumPy with manual gradient support. @@ -41,7 +41,7 @@ class NumpyBackend(Backend[np.ndarray]): registered_primitives = {} primitive_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops" primitive_grad_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops_grad" - registered_primitives_grad_fn: dict[str, Callable] = {} + registered_primitives_grad_fn: dict[str, Callable[..., Any]] = {} def __init__(self, device: str = "cpu", precision: int = 32) -> None: self._precision = precision diff --git a/mithril/framework/codegen/c_gen.py b/mithril/framework/codegen/c_gen.py index e595711c..61b13c6b 100644 --- a/mithril/framework/codegen/c_gen.py +++ b/mithril/framework/codegen/c_gen.py @@ -30,7 +30,7 @@ FinalCost = "final_cost" -class CGen(CodeGen): +class CGen(CodeGen[PyArray]): BACKWARD_FN_SUFFIX = "_grad" def __init__(self, pm: PhysicalModel[PyArray]) -> None: diff --git a/mithril/framework/codegen/code_gen.py b/mithril/framework/codegen/code_gen.py index 6281ce8f..e556a584 100644 --- a/mithril/framework/codegen/code_gen.py +++ b/mithril/framework/codegen/code_gen.py @@ -13,11 +13,11 @@ # limitations under the License. from abc import ABC, abstractmethod -from collections.abc import Callable -from typing import Any, Generic +from typing import Generic from mithril import DataType +from ..common import EvaluateAllType, EvaluateGradientsType, EvaluateType from ..physical.model import PhysicalModel @@ -34,5 +34,9 @@ def generate_code(self, file_path: str | None = None) -> None: @abstractmethod def compile_code( self, jit: bool - ) -> tuple[Callable[..., Any], Callable[..., Any], Callable[..., Any]]: + ) -> tuple[ + EvaluateType[DataType], + EvaluateGradientsType[DataType] | None, + EvaluateAllType[DataType] | None, + ]: raise NotImplementedError("compile_code is not implemented") diff --git a/mithril/framework/codegen/numpy_gen.py b/mithril/framework/codegen/numpy_gen.py index 3cbe3e81..e0748107 100644 --- a/mithril/framework/codegen/numpy_gen.py +++ b/mithril/framework/codegen/numpy_gen.py @@ -20,8 +20,6 @@ import numpy as np -from mithril import DataType - from ...backends.with_manualgrad.numpy_backend import NumpyBackend from ...core import Dtype from ...framework.physical.model import PhysicalModel @@ -31,21 +29,25 @@ prepare_function_args, ) from ..common import ( + DataEvalType, + EvaluateAllType, + EvaluateGradientsType, + EvaluateType, FinalCost, LossKey, + ParamsEvalType, Tensor, - ValueType, is_type_adjustment_required, ) from ..logical import PrimitiveModel, Scalar -from .python_gen import PythonCodeGen +from .python_gen import PythonCodeGen, RawGradientType from .utils import check_repr_inequality -class NumpyCodeGen(PythonCodeGen[DataType]): +class NumpyCodeGen(PythonCodeGen[np.ndarray[Any, Any]]): BACKWARD_FN_SUFFIX = "_grad" - def __init__(self, pm: PhysicalModel[DataType]) -> None: + def __init__(self, pm: PhysicalModel[np.ndarray[Any, Any]]) -> None: super().__init__(pm) assert isinstance(self.pm.backend, NumpyBackend) @@ -92,22 +94,26 @@ def generate_imports(self): def compile_code( self, jit: bool = False - ) -> tuple[Callable[..., Any], Callable[..., Any], Callable[..., Any]]: + ) -> tuple[ + EvaluateType[np.ndarray[Any, Any]], + EvaluateGradientsType[np.ndarray[Any, Any]] | None, + EvaluateAllType[np.ndarray[Any, Any]] | None, + ]: eval_fn, grad_fn = self.exec_generated_code() # TODO: Not looks good, and looks over complicated! def evaluate_gradients_wrapper_manualgrad( - params: dict[str, np.ndarray[Any, Any]], - data: dict[str, np.ndarray[Any, Any] | ValueType] | None = None, - output_gradients: dict[str, np.ndarray[Any, Any]] | None = None, + params: ParamsEvalType[np.ndarray[Any, Any]] | None = None, + data: DataEvalType[np.ndarray[Any, Any]] | None = None, + output_gradients: ParamsEvalType[np.ndarray[Any, Any]] | None = None, *, - grad_fn: Callable[..., Any], + grad_fn: RawGradientType[np.ndarray[Any, Any]], include_output: bool = False, ) -> ( - dict[str, np.ndarray[Any, Any]] + DataEvalType[np.ndarray[Any, Any]] | tuple[ - dict[str, np.ndarray[Any, Any]], - dict[str, np.ndarray[Any, Any] | dict[str, np.ndarray[Any, Any]]], + DataEvalType[np.ndarray[Any, Any]], + ParamsEvalType[np.ndarray[Any, Any]], ] ): if params is None: @@ -118,8 +124,6 @@ def evaluate_gradients_wrapper_manualgrad( # If evaluate_gradients called directly, first call evaluate. cached_data = self.pm.data_store.data_values - if data is None: - data = {} output: dict[str, np.ndarray[Any, Any]] = eval_fn( params=params, data=data, cache=cached_data ) @@ -133,7 +137,7 @@ def evaluate_gradients_wrapper_manualgrad( ): key_cache = cached_data.get(key + "_cache", {}) assert isinstance(key_cache, dict) - out_data: np.ndarray | None = None + out_data: np.ndarray[Any, Any] | None = None if key in params: out_data = params[key] elif "output" in key_cache: @@ -183,7 +187,7 @@ def evaluate_gradients_wrapper_manualgrad( if grad_fn is not None: grad_fn = partial(evaluate_gradients_wrapper_manualgrad, grad_fn=grad_fn) - return self.post_process_fns(eval_fn, grad_fn, jit) + return self.post_process_fns(eval_fn, grad_fn, jit) # type: ignore def get_primitive_details(self, output_key: str): model = self.pm._flat_graph.get_model(output_key) @@ -197,7 +201,7 @@ def get_primitive_details(self, output_key: str): def call_primitive( self, model: PrimitiveModel, - fn: Callable, + fn: Callable[..., Any], l_input_keys: list[str], g_input_keys: list[str], output_key: str, @@ -268,7 +272,7 @@ def generate_evaluate_gradients( ) -> ast.FunctionDef: input_body: list[ast.stmt] = [] function_body: list[ast.stmt] = [] - used_keys = set() + used_keys: set[str] = set() all_ignored_keys = ( ignore_grad_keys @@ -437,7 +441,10 @@ def generate_evaluate_gradients( ) idx_arg = ast.Constant(value=idx, kind=None) - default_args = {"output_gradient": grad_arg, "idx": idx_arg} + default_args: dict[str, ast.expr] = { + "output_gradient": grad_arg, + "idx": idx_arg, + } generated_fn, _used_keys = self.create_primitive_call( grad_fn, local_input_keys, global_input_keys, default_args ) diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index e85a0fd8..72b1d82f 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -15,14 +15,21 @@ import ast import importlib import keyword -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable from functools import partial from posixpath import basename, splitext -from typing import Any +from typing import Any, Generic, Literal, Protocol, overload from ...backends.backend import ParallelBackend from ...utils.func_utils import prepare_function_args -from ..common import MainValueType +from ..common import ( + DataEvalType, + DataType, + EvaluateAllType, + EvaluateGradientsType, + EvaluateType, + ParamsEvalType, +) from ..logical import PrimitiveModel from ..physical.model import PhysicalModel from ..utils import GeneratedFunction @@ -36,8 +43,58 @@ FinalCost = "final_cost" -class PythonCodeGen[DataType](CodeGen): - def __init__(self, pm: PhysicalModel[Any]) -> None: +class RawEvaluateType(Protocol, Generic[DataType]): + def __call__( + self, + params: ParamsEvalType[DataType] | None, + data: DataEvalType[DataType] | None, + cache: DataEvalType[DataType] | None, + ) -> DataEvalType[DataType]: ... + + +class RawGradientType(Protocol, Generic[DataType]): + def __call__( + self, + params: ParamsEvalType[DataType], + gradients: ParamsEvalType[DataType], + data: DataEvalType[DataType], + cache: DataEvalType[DataType], + ) -> ParamsEvalType[DataType]: ... + + +class ManualGradWrapperFn(Protocol, Generic[DataType]): + @overload + def __call__( + self, + params: ParamsEvalType[DataType], + data: DataEvalType[DataType], + output_gradients: ParamsEvalType[DataType], + include_output: Literal[True], + ) -> tuple[DataEvalType[DataType], ParamsEvalType[DataType]]: ... + + @overload + def __call__( + self, + params: ParamsEvalType[DataType], + data: DataEvalType[DataType], + output_gradients: ParamsEvalType[DataType], + include_output: Literal[False], + ) -> ParamsEvalType[DataType]: ... + + def __call__( + self, + params: ParamsEvalType[DataType], + data: DataEvalType[DataType], + output_gradients: ParamsEvalType[DataType], + include_output: bool, + ) -> ( + ParamsEvalType[DataType] + | tuple[DataEvalType[DataType], ParamsEvalType[DataType]] + ): ... + + +class PythonCodeGen(CodeGen[Any], Generic[DataType]): + def __init__(self, pm: PhysicalModel[DataType]) -> None: super().__init__(pm) self.module = ast.parse("") @@ -73,7 +130,13 @@ def write_code(self, file_path: str): with open(file_path, "w") as file: file.write(self.code) - def compile_code(self, jit: bool = False): + def compile_code( + self, jit: bool = False + ) -> tuple[ + EvaluateType[DataType], + EvaluateGradientsType[DataType] | None, + EvaluateAllType[DataType] | None, + ]: eval_fn, grad_fn = self.exec_generated_code() return self.post_process_fns(eval_fn, grad_fn, jit) @@ -128,8 +191,15 @@ def exec_generated_code( return eval_fn, grad_fn def post_process_fns( - self, raw_eval_fn: Callable, raw_grad_fn: Callable | None, jit: bool - ): + self, + raw_eval_fn: RawEvaluateType[DataType], + raw_grad_fn: ManualGradWrapperFn[DataType] | None, + jit: bool, + ) -> tuple[ + EvaluateType[DataType], + EvaluateGradientsType[DataType] | None, + EvaluateAllType[DataType] | None, + ]: """In this function going to wrap the raw functions with some additional functionalities. @@ -138,7 +208,7 @@ def post_process_fns( 3. If jit is True, going to compile the functions with jit fn. """ - eval_fn: Callable | partial = partial( + eval_fn: EvaluateType[DataType] | partial[Any] = partial( self.compute_evaluate, fn=raw_eval_fn, cache=self.pm.data_store.data_values, @@ -235,7 +305,7 @@ def get_primitive_details(self, output_key: str): def call_primitive( self, model: PrimitiveModel, - fn: Callable, + fn: Callable[..., Any], l_input_keys: list[str], g_input_keys: list[str], output_key: str, @@ -257,7 +327,7 @@ def generate_evaluate(self): function_body: list[ast.stmt] = [] return_values: list[ast.expr] = [] - used_keys = set() + used_keys: set[str] = set() used_keys |= set(self.pm._flat_graph.output_dict.values()) unused_keys = self.pm.data_store.unused_keys @@ -399,7 +469,7 @@ def append_inputs(self, input_body: list[ast.stmt], key: str, dict_type: str): def create_primitive_call( self, - function: Callable, + function: Callable[..., Any], local_keys: list[str], global_keys: list[str], default_args: dict[str, ast.expr] | None = None, @@ -445,7 +515,7 @@ def create_primitive_call( def create_primitive_call_targets( self, output_key: str, model: PrimitiveModel, inference: bool - ) -> tuple[Sequence[ast.expr | ast.Name], set[str]]: + ) -> tuple[list[ast.expr], set[str]]: if ( keyword.iskeyword(output_key) or output_key in self.pm.backend.primitive_function_dict @@ -454,7 +524,7 @@ def create_primitive_call_targets( else: target_name = output_key - targets = [ + targets: list[ast.expr] = [ ast.Name( id=target_name, ctx=ast.Store(), @@ -472,19 +542,25 @@ def add_partial_function(self, formula_key: str): def compute_evaluate( self, - params: dict[str, DataType] | None = None, - data: dict[str, DataType] | None = None, - cache: dict[str, DataType | MainValueType | str] | None = None, + params: ParamsEvalType[DataType] | None = None, + data: DataEvalType[DataType] | None = None, + cache: DataEvalType[DataType] | None = None, *, - fn: Callable, - ): + fn: RawEvaluateType[DataType], + ) -> DataEvalType[DataType]: return fn(params, data, cache) def create_gradient_fn( - self, raw_evaluate_fn: Callable, raw_evaluate_grad_fn: Callable | None + self, + # raw_evaluate_fn: RawEvaluateType[DataType], + # raw_evaluate_grad_fn: ManualGradWrapperFn[DataType] | None, + raw_evaluate_fn: RawEvaluateType[DataType], + raw_evaluate_grad_fn: ManualGradWrapperFn[DataType] | None, ): + fn_all: EvaluateAllType[DataType] + grad_fn: EvaluateGradientsType[DataType] if not self.pm.backend.is_manualgrad: - grad_fn = partial( + grad_fn = partial( # type: ignore self.compute_gradients, raw_evaluate_fn=raw_evaluate_fn, cache=self.pm.data_store.data_values, @@ -501,20 +577,48 @@ def create_gradient_fn( else: assert raw_evaluate_grad_fn is not None, "Gradient function is not defined!" - fn_all = partial(raw_evaluate_grad_fn, include_output=True) + fn_all = partial(raw_evaluate_grad_fn, include_output=True) # type: ignore + grad_fn = partial(raw_evaluate_grad_fn, include_output=False) # type: ignore - return raw_evaluate_grad_fn, fn_all + return grad_fn, fn_all + + @overload + def compute_gradients( + self, + params: ParamsEvalType[DataType], + data: DataEvalType[DataType] | None, + output_gradients: ParamsEvalType[DataType] | None, + cache: DataEvalType[DataType] | None, + include_output: Literal[True], + *, + raw_evaluate_fn: RawEvaluateType[DataType], + ) -> tuple[DataEvalType[DataType], ParamsEvalType[DataType]]: ... + + @overload + def compute_gradients( + self, + params: ParamsEvalType[DataType], + data: DataEvalType[DataType] | None, + output_gradients: ParamsEvalType[DataType] | None, + cache: DataEvalType[DataType] | None, + include_output: Literal[False], + *, + raw_evaluate_fn: RawEvaluateType[DataType], + ) -> ParamsEvalType[DataType]: ... def compute_gradients( self, - params: dict[str, DataType], - data: dict[str, DataType | MainValueType] | None = None, - output_gradients: dict[str, DataType] | None = None, - cache: Mapping[str, DataType | MainValueType | str] | None = None, + params: ParamsEvalType[DataType], + data: DataEvalType[DataType] | None = None, + output_gradients: ParamsEvalType[DataType] | None = None, + cache: DataEvalType[DataType] | None = None, include_output: bool = False, *, - raw_evaluate_fn: Callable, - ) -> dict[str, DataType] | tuple[dict[str, DataType], dict[str, DataType]]: + raw_evaluate_fn: RawEvaluateType[DataType], + ) -> ( + tuple[DataEvalType[DataType], ParamsEvalType[DataType]] + | ParamsEvalType[DataType] + ): # Initialize loss output gradients as None. If FinalCost is # contained in the compiled model, initialize its gradient # with ones. If somehow one wants to set it to another gradient @@ -538,9 +642,6 @@ def compute_gradients( "Models with any losses can not take any gradients for other outputs!" ) - if output_gradients is None: - output_gradients = {} - # Sort gradients with output order output_gradients = { key: output_gradients[key] @@ -556,35 +657,38 @@ def compute_gradients( if (key not in total_output_gradients) and key != FinalCost } ) - assert cache is not None - partial_fn: Callable = partial( + partial_fn: Callable[ + [ParamsEvalType[DataType]], + tuple[DataEvalType[DataType], DataEvalType[DataType]], + ] = partial( self.filter_ignored_outputs, data=data, cache=cache, ignore_grad_keys=total_ignore_grad_keys, raw_evaluate_fn=raw_evaluate_fn, ) - output, input_gradients, aux = self.pm.backend.vjp( - partial_fn, params, cotangents=total_output_gradients, has_aux=True + partial_fn, # type: ignore + params, + cotangents=total_output_gradients, + has_aux=True, ) - output |= aux + all_outputs: DataEvalType[DataType] = output | aux - assert not callable(input_gradients) if include_output: - return output, input_gradients + return all_outputs, input_gradients else: return input_gradients def filter_ignored_outputs( self, - params: dict[str, DataType] | None = None, - data: Mapping[str, MainValueType | DataType | str] | None = None, - cache: Mapping[str, MainValueType | DataType | str] | None = None, - ignore_grad_keys=None, + params: ParamsEvalType[DataType] | None = None, + data: DataEvalType[DataType] | None = None, + cache: DataEvalType[DataType] | None = None, + ignore_grad_keys: set[str] | None = None, *, - raw_evaluate_fn: Callable, - ) -> tuple[dict[str, DataType], dict[str, DataType]]: + raw_evaluate_fn: RawEvaluateType[DataType], + ) -> tuple[ParamsEvalType[DataType], ParamsEvalType[DataType]]: if params is None: params = {} if data is None: @@ -597,7 +701,7 @@ def filter_ignored_outputs( outputs = raw_evaluate_fn(params, data=data, cache=cache) aux = { - key: outputs.pop(key) + key: outputs.pop(key) # type: ignore for key in list(outputs.keys()) if key in ignore_grad_keys } @@ -608,4 +712,4 @@ def filter_ignored_outputs( " at least one of the {list(aux.keys())}" ) - return outputs, aux + return outputs, aux # type: ignore diff --git a/mithril/framework/codegen/torch_gen.py b/mithril/framework/codegen/torch_gen.py index 8b7c4756..be1f3458 100644 --- a/mithril/framework/codegen/torch_gen.py +++ b/mithril/framework/codegen/torch_gen.py @@ -14,6 +14,9 @@ import ast from collections.abc import Callable +from typing import Any + +import torch from ...backends.with_autograd.torch_backend import TorchBackend from ..logical import PrimitiveModel @@ -21,8 +24,8 @@ from .python_gen import PythonCodeGen -class TorchCodeGen(PythonCodeGen): - def __init__(self, pm: PhysicalModel) -> None: +class TorchCodeGen(PythonCodeGen[torch.Tensor]): + def __init__(self, pm: PhysicalModel[torch.Tensor]) -> None: super().__init__(pm) self.is_parallel_defined = False @@ -32,7 +35,7 @@ def __init__(self, pm: PhysicalModel) -> None: def call_primitive( self, model: PrimitiveModel, - fn: Callable, + fn: Callable[..., Any], l_input_keys: list[str], g_input_keys: list[str], output_key: str, @@ -84,4 +87,4 @@ def call_primitive( keywords=[], ) - return ast.Assign(targets, generated_fn), used_keys | _used_keys # type: ignore + return ast.Assign(targets, generated_fn), used_keys | _used_keys diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 34eb829d..d0adf97b 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -21,7 +21,7 @@ from functools import partial, reduce from itertools import combinations, cycle, product, zip_longest from types import EllipsisType, GenericAlias, UnionType -from typing import Any, Literal, TypeVar, overload +from typing import Any, Generic, Literal, Protocol, TypeVar, overload from ..backends.backend import Backend from ..core import ( @@ -214,10 +214,40 @@ class KeyType(Enum): ) -_TensorTypes = int | float | Constant | tuple +_TensorTypes = int | float | Constant TensorValueType = _TensorTypes | tuple["TensorValueType", ...] | list["TensorValueType"] +ParamsEvalType = dict[str, DataType] +DataEvalType = Mapping[str, DataType | MainValueType | str] + + +class EvaluateType(Protocol, Generic[DataType]): + def __call__( + self, + params: ParamsEvalType[DataType] | None, + data: DataEvalType[DataType] | None, + ) -> DataEvalType[DataType]: ... + + +class EvaluateGradientsType(Protocol, Generic[DataType]): + def __call__( + self, + params: ParamsEvalType[DataType] | None, + data: DataEvalType[DataType] | None, + output_gradients: ParamsEvalType[DataType] | None, + ) -> ParamsEvalType[DataType]: ... + + +class EvaluateAllType(Protocol, Generic[DataType]): + def __call__( + self, + params: ParamsEvalType[DataType] | None, + data: DataEvalType[DataType] | None, + output_gradients: ParamsEvalType[DataType] | None, + ) -> tuple[DataEvalType[DataType], ParamsEvalType[DataType]]: ... + + LossKey = "loss" FinalCost = "final_cost" @@ -653,6 +683,7 @@ def match(self, other: Tensor[DataType] | Scalar) -> Updates: valued, non_valued = (self, other) if self.is_valued else (other, self) assert isinstance(valued, Tensor | Scalar) assert not isinstance(valued.value, ToBeDetermined) + assert isinstance(non_valued, type(valued)) updates |= non_valued.set_value(valued.value) if non_valued == other: if isinstance(other, Tensor): diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index 4aedaeb4..18b9dd12 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -439,7 +439,7 @@ def check_index_type_compatibility( def scalar_item_reduce_input_type( output_type: type | UnionType | GenericAlias, input_type: type | UnionType | GenericAlias, - index, + index: int | ToBeDetermined, ): possible_types = [] out_origin: type[list] | type[tuple] | type[UnionType] | None = get_origin( @@ -1731,7 +1731,7 @@ def concat_constraints( def pad_constraints( - output: Tensor, input: Tensor, pad_width: Scalar + output: Tensor[Any], input: Tensor[Any], pad_width: Scalar ) -> ConstrainResultType: updates = Updates() pad_value: tuple[tuple[int, int], ...] | ToBeDetermined = pad_width.value # type: ignore @@ -1740,7 +1740,9 @@ def pad_constraints( assert input_shape is not None assert output_shape is not None - def process_shape(shape, pad_value, forward=True): + def process_shape( + shape: ShapeRepr, pad_value: tuple[tuple[int, int], ...], forward: bool = True + ): prefix: list[Uniadic] = [] root = None suffix: list[Uniadic] = [] @@ -3400,7 +3402,9 @@ def tensor_item_constraints( def tensor_item_constraint_helper( - item_values: tuple | list, input_unis: list[Uniadic] + item_values: tuple[slice | int | EllipsisType | None, ...] + | list[slice | int | None | EllipsisType], + input_unis: list[Uniadic], ) -> list[Uniadic]: # calculates output uniadics based on given item values and # input uniadics. @@ -3410,7 +3414,7 @@ def tensor_item_constraint_helper( # input_unis = [Uniadic(10), Uniadic(5), Uniadic(2)] --> items = [Uniadic(2), # Uniadic(1), Uniadic(1), Uniadic(2)] - items = [] + items: list[Uniadic] = [] idx = 0 for item in item_values: if item is None: diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 118a6f22..e722670d 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -27,8 +27,13 @@ TBD, Connection, ConnectionData, + DataEvalType, + EvaluateAllType, + EvaluateGradientsType, + EvaluateType, IOKey, MainValueType, + ParamsEvalType, Scalar, Table, Tensor, @@ -680,48 +685,15 @@ def _pre_compile( def generate_functions( self, - eval_fn: Callable[ - [dict[str, DataType] | None, Mapping[str, MainValueType | DataType] | None], - Mapping[str, MainValueType | DataType], - ], - grad_fn: Callable[ - [ - dict[str, DataType] | None, - Mapping[str, MainValueType | DataType] | None, - dict[str, DataType] | None, - ], - dict[str, DataType], - ], - eval_all_fn: Callable[ - [ - dict[str, DataType] | None, - Mapping[str, MainValueType | DataType] | None, - dict[str, DataType] | None, - ], - tuple[Mapping[str, MainValueType | DataType], dict[str, DataType]], - ], + eval_fn: EvaluateType[DataType], + grad_fn: EvaluateGradientsType[DataType] | None, + eval_all_fn: EvaluateAllType[DataType] | None, ) -> None: - self._generated_eval_fn: Callable[ - [dict[str, DataType] | None, Mapping[str, MainValueType | DataType] | None], - Mapping[str, MainValueType | DataType], - ] = eval_fn - self._generated_compute_gradients_fn: Callable[ - [ - dict[str, DataType] | None, - Mapping[str, MainValueType | DataType] | None, - dict[str, DataType] | None, - ], - dict[str, DataType], - ] = grad_fn - - self._generated_evaluate_all_fn: Callable[ - [ - dict[str, DataType] | None, - Mapping[str, MainValueType | DataType] | None, - dict[str, DataType] | None, - ], - tuple[Mapping[str, MainValueType | DataType], dict[str, DataType]], - ] = eval_all_fn + self._generated_eval_fn: EvaluateType[DataType] = eval_fn + self._generated_compute_gradients_fn: EvaluateGradientsType[DataType] | None = ( + grad_fn + ) + self._generated_evaluate_all_fn: EvaluateAllType[DataType] | None = eval_all_fn def create_jacobian_fn(self, generated_fn: Callable): # TODO: Fix this method to make it picklable! @@ -1211,9 +1183,9 @@ def _replace_with_primitive( def evaluate( self, - params: dict[str, DataType] | None = None, - data: Mapping[str, DataType | MainValueType] | None = None, - ) -> Mapping[str, MainValueType | DataType]: + params: ParamsEvalType[DataType] | None = None, + data: DataEvalType[DataType] | None = None, + ) -> DataEvalType[DataType]: if ( isinstance(self.backend, ParallelBackend) and self.backend._parallel_manager is not None @@ -1224,10 +1196,10 @@ def evaluate( def evaluate_gradients( self, - params: dict[str, DataType] | None = None, - data: Mapping[str, DataType | MainValueType] | None = None, - output_gradients: dict[str, DataType] | None = None, - ) -> dict[str, DataType]: + params: ParamsEvalType[DataType] | None = None, + data: DataEvalType[DataType] | None = None, + output_gradients: ParamsEvalType[DataType] | None = None, + ) -> ParamsEvalType[DataType]: if self.inference: raise NotImplementedError( "Inference mode does not support gradients calculation" @@ -1240,14 +1212,14 @@ def evaluate_gradients( params, data, output_gradients, fn_name="eval_grad_fn" ) else: - return self._generated_compute_gradients_fn(params, data, output_gradients) + return self._generated_compute_gradients_fn(params, data, output_gradients) # type: ignore def evaluate_all( self, - params: dict[str, DataType] | None = None, - data: Mapping[str, DataType | MainValueType] | None = None, - output_gradients: dict[str, DataType] | None = None, - ) -> tuple[Mapping[str, MainValueType | DataType], dict[str, DataType]]: + params: ParamsEvalType[DataType] | None = None, + data: DataEvalType[DataType] | None = None, + output_gradients: ParamsEvalType[DataType] | None = None, + ) -> tuple[DataEvalType[DataType], ParamsEvalType[DataType]]: if self.inference: raise NotImplementedError( "Inferece mode does not support gradients calculation" @@ -1260,4 +1232,4 @@ def evaluate_all( params, data, output_gradients, fn_name="eval_all_fn" ) else: - return self._generated_evaluate_all_fn(params, data, output_gradients) + return self._generated_evaluate_all_fn(params, data, output_gradients) # type: ignore From 4bf018451cc7ca863d956b6da0b11abae29c5037 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 5 Dec 2024 11:41:07 +0300 Subject: [PATCH 12/26] overload decorator is added to evaluate_gradients_wrapper_manualgrad function --- mithril/framework/codegen/numpy_gen.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/mithril/framework/codegen/numpy_gen.py b/mithril/framework/codegen/numpy_gen.py index e0748107..0dc592a1 100644 --- a/mithril/framework/codegen/numpy_gen.py +++ b/mithril/framework/codegen/numpy_gen.py @@ -16,7 +16,7 @@ import keyword from collections.abc import Callable from functools import partial -from typing import Any +from typing import Any, Literal, overload import numpy as np @@ -102,6 +102,30 @@ def compile_code( eval_fn, grad_fn = self.exec_generated_code() # TODO: Not looks good, and looks over complicated! + + @overload + def evaluate_gradients_wrapper_manualgrad( + params: ParamsEvalType[np.ndarray[Any, Any]] | None, + data: DataEvalType[np.ndarray[Any, Any]] | None, + output_gradients: ParamsEvalType[np.ndarray[Any, Any]] | None, + *, + grad_fn: RawGradientType[np.ndarray[Any, Any]], + include_output: Literal[False], + ) -> DataEvalType[np.ndarray[Any, Any]]: ... + + @overload + def evaluate_gradients_wrapper_manualgrad( + params: ParamsEvalType[np.ndarray[Any, Any]] | None, + data: DataEvalType[np.ndarray[Any, Any]] | None, + output_gradients: ParamsEvalType[np.ndarray[Any, Any]] | None, + *, + grad_fn: RawGradientType[np.ndarray[Any, Any]], + include_output: Literal[True], + ) -> tuple[ + DataEvalType[np.ndarray[Any, Any]], + ParamsEvalType[np.ndarray[Any, Any]], + ]: ... + def evaluate_gradients_wrapper_manualgrad( params: ParamsEvalType[np.ndarray[Any, Any]] | None = None, data: DataEvalType[np.ndarray[Any, Any]] | None = None, From ac7051dd937bed66fac788ed4858ce4b2f7ee8d5 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 9 Dec 2024 11:01:00 +0300 Subject: [PATCH 13/26] updated based on reviews --- mithril/backends/with_manualgrad/numpy_backend/backend.py | 4 ++-- mithril/framework/codegen/python_gen.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index 3be576e2..6323dc93 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -38,10 +38,10 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]): type = "numpy" - registered_primitives = {} + registered_primitives: dict[str, Callable[..., Any]] = {} primitive_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops" primitive_grad_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops_grad" - registered_primitives_grad_fn: dict[str, Callable[..., Any]] = {} + registered_primitives_grad_fn: dict[str, Callable[..., np.ndarray[Any, Any]]] = {} def __init__(self, device: str = "cpu", precision: int = 32) -> None: self._precision = precision diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index 97d1c386..d4152b0e 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -93,9 +93,9 @@ def __call__( ): ... -class PythonCodeGen(CodeGen[Any], Generic[DataType]): +class PythonCodeGen(CodeGen[DataType]): def __init__(self, pm: PhysicalModel[DataType]) -> None: - super().__init__(pm) + super().__init__(pm) # type: ignore self.module = ast.parse("") self.defined_partial_fns: set[str] = set() From cdda8786f8540b281ccf1c9310b3669f2967f390 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 9 Dec 2024 11:28:07 +0300 Subject: [PATCH 14/26] CI pre-commit fix attempt --- mithril/backends/with_manualgrad/numpy_backend/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index 6323dc93..7206c3f5 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -41,7 +41,7 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]): registered_primitives: dict[str, Callable[..., Any]] = {} primitive_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops" primitive_grad_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops_grad" - registered_primitives_grad_fn: dict[str, Callable[..., np.ndarray[Any, Any]]] = {} + registered_primitives_grad_fn: dict[str, Callable[..., Any]] = {} def __init__(self, device: str = "cpu", precision: int = 32) -> None: self._precision = precision From 7c3c644f1fd517fd9aed291841458c35ba9eb3dc Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 9 Dec 2024 11:35:09 +0300 Subject: [PATCH 15/26] codegen typehints are fixed --- mithril/backends/with_manualgrad/numpy_backend/backend.py | 2 +- mithril/framework/codegen/python_gen.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index 7206c3f5..6323dc93 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -41,7 +41,7 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]): registered_primitives: dict[str, Callable[..., Any]] = {} primitive_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops" primitive_grad_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops_grad" - registered_primitives_grad_fn: dict[str, Callable[..., Any]] = {} + registered_primitives_grad_fn: dict[str, Callable[..., np.ndarray[Any, Any]]] = {} def __init__(self, device: str = "cpu", precision: int = 32) -> None: self._precision = precision diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index d4152b0e..97d1c386 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -93,9 +93,9 @@ def __call__( ): ... -class PythonCodeGen(CodeGen[DataType]): +class PythonCodeGen(CodeGen[Any], Generic[DataType]): def __init__(self, pm: PhysicalModel[DataType]) -> None: - super().__init__(pm) # type: ignore + super().__init__(pm) self.module = ast.parse("") self.defined_partial_fns: set[str] = set() From 8938418addf77508155a871872a88279aea44114 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 9 Dec 2024 11:40:01 +0300 Subject: [PATCH 16/26] codegen typehints are fixed --- mithril/backends/with_manualgrad/numpy_backend/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index 6323dc93..7206c3f5 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -41,7 +41,7 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]): registered_primitives: dict[str, Callable[..., Any]] = {} primitive_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops" primitive_grad_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops_grad" - registered_primitives_grad_fn: dict[str, Callable[..., np.ndarray[Any, Any]]] = {} + registered_primitives_grad_fn: dict[str, Callable[..., Any]] = {} def __init__(self, device: str = "cpu", precision: int = 32) -> None: self._precision = precision From 1d9ec33e8b374b8273e33509dedd3f4dd0acffbc Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 9 Dec 2024 11:44:48 +0300 Subject: [PATCH 17/26] codegen typehints are fixed --- mithril/backends/with_manualgrad/numpy_backend/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index 7206c3f5..3be576e2 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -38,7 +38,7 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]): type = "numpy" - registered_primitives: dict[str, Callable[..., Any]] = {} + registered_primitives = {} primitive_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops" primitive_grad_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops_grad" registered_primitives_grad_fn: dict[str, Callable[..., Any]] = {} From 47f285beb48881eccba95b1d84b4335fe144c365 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 9 Dec 2024 13:16:50 +0300 Subject: [PATCH 18/26] numpy 2.2.0 support is added --- mithril/backends/with_manualgrad/numpy_backend/ops.py | 8 +++----- tests/scripts/test_all_models.py | 4 ++-- tests/scripts/test_utils.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops.py b/mithril/backends/with_manualgrad/numpy_backend/ops.py index 0073d107..c64e28dc 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import copy import logging +import os from collections.abc import Callable, Iterator, Sequence from functools import partial from itertools import combinations_with_replacement @@ -85,7 +85,7 @@ write_into_cache, ) -np._set_promotion_state("legacy") +os.environ["NPY_PROMOTION_STATE"] = "legacy" AxisType = None | int | Sequence[int] @@ -893,9 +893,7 @@ def to_tensor( return np.array(input[0], dtype=get_type(input[0], precision=precision)) -def tensor_to_list( - input: np.ndarray, cache: CacheType | None = None -) -> NestedFloatOrIntOrBoolList: +def tensor_to_list(input: np.ndarray, cache: CacheType | None = None): return input.tolist() diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index 5cd5b66e..64982d8b 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -3171,8 +3171,8 @@ def test_groupnorm_2(): model = GroupNorm(4) input = np.arange(160, dtype=np.float32) - input = input.reshape(1, 16, 10, 1) - input = np.broadcast_to(input, (2, 16, 10, 4)) + input = input.reshape((1, 16, 10, 1)) # type: ignore + input = np.broadcast_to(input, (2, 16, 10, 4)) # type: ignore input = np.concatenate([input, 0.5 * input], axis=-1) weight = np.random.randn(1, 16, 1, 1) diff --git a/tests/scripts/test_utils.py b/tests/scripts/test_utils.py index ed04f581..1980bc2d 100644 --- a/tests/scripts/test_utils.py +++ b/tests/scripts/test_utils.py @@ -223,7 +223,7 @@ def dict_to_random(input: dict, random_shapes: dict | None = None): def randomizer( input: list[str | int | bool | list[int]], -) -> list[int] | int | str | bool: +): if len(input) == 0: return [] elif isinstance(val := input[0], bool) or isinstance(val, str): From 1866707449e5177930a18ab4d1ac95b0b4ac550c Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Mon, 9 Dec 2024 13:25:49 +0300 Subject: [PATCH 19/26] updated based on reviews --- mithril/backends/with_manualgrad/numpy_backend/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index 3be576e2..fca06f99 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -41,7 +41,7 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]): registered_primitives = {} primitive_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops" primitive_grad_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops_grad" - registered_primitives_grad_fn: dict[str, Callable[..., Any]] = {} + registered_primitives_grad_fn: dict[str, Callable[..., np.ndarray[Any, Any]]] = {} def __init__(self, device: str = "cpu", precision: int = 32) -> None: self._precision = precision From 0450bb233d4613e9cb0c4d60bef2c264639b80dc Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 12 Dec 2024 16:15:27 +0300 Subject: [PATCH 20/26] type annotations in backend files are fixed partially --- mithril/backends/backend.py | 34 +- mithril/backends/parallel.py | 12 +- .../with_autograd/common_primitives.py | 44 +- .../with_autograd/jax_backend/backend.py | 66 ++- .../backends/with_autograd/jax_backend/ops.py | 31 +- .../with_autograd/jax_backend/parallel.py | 7 +- .../with_autograd/jax_backend/utils.py | 66 ++- .../with_autograd/mlx_backend/backend.py | 30 +- .../backends/with_autograd/mlx_backend/ops.py | 26 +- .../with_autograd/mlx_backend/utils.py | 77 ++- .../with_autograd/torch_backend/backend.py | 45 +- .../with_autograd/torch_backend/ops.py | 30 +- .../with_autograd/torch_backend/parallel.py | 58 +- .../with_autograd/torch_backend/utils.py | 123 ++-- .../with_manualgrad/c_backend/backend.py | 16 +- .../with_manualgrad/c_backend/src/array.py | 3 +- .../with_manualgrad/c_backend/src/array.pyi | 30 +- .../with_manualgrad/c_backend/src/utils.py | 6 +- .../with_manualgrad/common_primitives.py | 6 +- .../with_manualgrad/numpy_backend/backend.py | 23 +- .../with_manualgrad/numpy_backend/ops.py | 406 +++++++------ .../with_manualgrad/numpy_backend/ops_grad.py | 547 ++++++++---------- .../with_manualgrad/numpy_backend/utils.py | 194 +++---- mithril/core.py | 4 +- mithril/utils/utils.py | 4 +- tests/scripts/test_parallel.py | 6 +- tests/scripts/test_scripts.py | 1 + 27 files changed, 1039 insertions(+), 856 deletions(-) diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py index 5ca12be5..a6f672bb 100644 --- a/mithril/backends/backend.py +++ b/mithril/backends/backend.py @@ -36,7 +36,7 @@ class Backend(ABC, Generic[DataType]): device_type = None supported_precisions = [16, 32, 64] is_installed = True - _device: str + _device: Any _precision: int primitive_function_dict: dict[str, Callable[..., DataType | Any]] registered_primitives: dict[str, Callable[..., DataType]] @@ -65,7 +65,7 @@ def device(self): return self._device @property - def inf(self): + def inf(self) -> DataType | float: raise NotImplementedError("inf is not implemented") @property @@ -84,7 +84,7 @@ def get_backend_array_type(self): # noqa: B902 raise NotImplementedError("get_backend_array_type is not implemented") @staticmethod - def register_primitive(fn: Callable) -> None: + def register_primitive(fn: Callable[..., Any]) -> None: raise NotImplementedError("register_primitive is not implemented!") @abstractmethod @@ -93,10 +93,12 @@ def set_seed(self, seed: int): "set_seed function must be overriden for every backend individually!" ) - def to_device(self, data: DataType, device: str, asynchronous: bool = True): + def to_device( + self, data: DataType, device: str, asynchronous: bool = True + ) -> DataType: raise RuntimeError("Backend does not support to_device method!") - def block_until_ready(self, data: DataType): + def block_until_ready(self, data: DataType) -> DataType | None: raise RuntimeError("Backend does not support block_until_ready method!") def empty_cache(self): # noqa: B027 @@ -316,7 +318,7 @@ def ones( raise NotImplementedError("ones is not implemented!") def ones_like( - self, array: DataType, *, dtype: core.Dtype | None = None + self, input: DataType, *, dtype: core.Dtype | None = None ) -> DataType: """Returns a new backend array filled with ones, with the same size, same dtype and same device with the given array. @@ -337,7 +339,7 @@ def ones_like( raise NotImplementedError("ones_like is not implemented!") def zeros_like( - self, array: DataType, *, dtype: core.Dtype | None = None + self, input: DataType, *, dtype: core.Dtype | None = None ) -> DataType: """Returns a new backend array filled with zeros, with the same size, same dtype and same device with the given array. @@ -588,7 +590,7 @@ def softplus(self, input: DataType) -> DataType: """ raise NotImplementedError("softplus is not implemented!") - def stop_gradient(self, data: DataType) -> DataType: + def stop_gradient(self, input: DataType) -> DataType: """ Stop the gradient computation for the given data. @@ -677,7 +679,7 @@ def expand_dims(self, input: DataType, axis: int) -> DataType: """ raise NotImplementedError("expand_dims is not implemented!") - def stack(self, arrays: list[DataType], axis: int = 0) -> DataType: + def stack(self, inputs: list[DataType], axis: int = 0) -> DataType: """ Stack a sequence of arrays along a new axis. @@ -693,7 +695,7 @@ def stack(self, arrays: list[DataType], axis: int = 0) -> DataType: """ raise NotImplementedError("stack is not implemented!") - def cat(self, arrays: list[DataType], axis: int = 0) -> DataType: + def cat(self, inputs: list[DataType], axis: int = 0) -> DataType: """ Concatenate a sequence of arrays along an existing axis. @@ -814,12 +816,12 @@ def any(self, input: DataType) -> DataType: raise NotImplementedError("any is not implemented!") def transpose( - self, data: DataType, axes: tuple[int, ...] | list[int] | None + self, input: DataType, axes: tuple[int, ...] | list[int] | None ) -> DataType: raise NotImplementedError() def unique( - self, input: DataType, **kwargs + self, input: DataType, **kwargs: Any ) -> tuple[DataType, DataType | None, DataType | None]: raise NotImplementedError("unique is not implemented!") @@ -830,11 +832,7 @@ def where(self, cond: DataType, input1: DataType, input2: DataType) -> DataType: raise NotImplementedError("where is not implemented!") def multinomial( - self, - probs: DataType, - num_samples: int, - replacement: bool = False, - **kwargs, + self, probs: DataType, num_samples: int, replacement: bool = False ) -> DataType: raise NotImplementedError("multinomial is not implemented!") @@ -1378,7 +1376,7 @@ def _register_callable( def _run_callable(self, *primals, fn_name: str): raise NotImplementedError() - def _create_parallel(self, device_mesh: tuple[int, ...]) -> Parallel: + def _create_parallel(self, device_mesh: tuple[int, ...]): raise NotImplementedError( f"{self.type.capitalize()} backend does not support parallelization!" ) diff --git a/mithril/backends/parallel.py b/mithril/backends/parallel.py index 1c520a5f..ea8451a5 100644 --- a/mithril/backends/parallel.py +++ b/mithril/backends/parallel.py @@ -14,15 +14,15 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Generic +from typing import Any, Generic from ..core import DataType class Parallel(ABC, Generic[DataType]): - def __init__(self, n_devices) -> None: + def __init__(self, n_devices: int) -> None: self.n_devices = n_devices - self.callables: dict[str, Callable] = {} + self.callables: dict[str, Callable[..., Any]] = {} if self.n_devices <= 1: raise ValueError( @@ -31,11 +31,13 @@ def __init__(self, n_devices) -> None: ) @abstractmethod - def run_callable(self, *primals, fn_name: str): + def run_callable(self, *primals: Any, fn_name: str) -> dict[str, Any]: raise NotImplementedError() @abstractmethod - def parallelize(self, tensor: DataType, device_mesh: tuple[int, ...] | None = None): + def parallelize( + self, tensor: DataType, device_mesh: tuple[int, ...] | None = None + ) -> dict[str, Any]: raise NotImplementedError() def clean_up(self): diff --git a/mithril/backends/with_autograd/common_primitives.py b/mithril/backends/with_autograd/common_primitives.py index 31af995f..afb7f22c 100644 --- a/mithril/backends/with_autograd/common_primitives.py +++ b/mithril/backends/with_autograd/common_primitives.py @@ -139,11 +139,13 @@ def squared_error(input: DataType, target: DataType): return (input - target) ** 2 -def minus(input: DataType): +def minus(input: DataType) -> DataType: return -input -def transpose(input: DataType, axes: tuple[int, ...] | list[int] | None = None): +def transpose( + input: DataType, axes: tuple[int, ...] | list[int] | None = None +) -> DataType: if not axes: return input.T return input.transpose(*axes) @@ -167,11 +169,11 @@ def buffer(input: DataType): return input -def permute_tensor(input: DataType, indices: DataType): +def permute_tensor(input: DataType, indices: DataType) -> DataType: return input[indices] # type: ignore -def reshape(input: DataType, shape: tuple[int, ...]): +def reshape(input: DataType, shape: tuple[int, ...]) -> DataType: return input.reshape(shape) @@ -191,7 +193,7 @@ def cartesian_diff(left: DataType, right: DataType): return left[:, None, :] - right[None, :, :] -def primitive_embedding(input: DataType, embedding_matrix: DataType): +def primitive_embedding(input: DataType, embedding_matrix: DataType) -> DataType: return embedding_matrix[input] # type: ignore @@ -226,7 +228,10 @@ def to_list(*args: tuple[int | float | bool, ...]): return list(args) -def padding_converter_1d(input, kernel_size): +def padding_converter_1d( + input: PaddingType | int | Sequence[int], kernel_size: tuple[int, int] | int +) -> tuple[int, int]: + output: tuple[int, int] if isinstance(input, PaddingType): if input == PaddingType.VALID: output = (0, 0) @@ -243,15 +248,18 @@ def padding_converter_1d(input, kernel_size): elif isinstance(input, int): output = (input, input) - elif isinstance(input, Sequence): + else: if isinstance(input[0], Sequence) or isinstance(input[1], Sequence): raise RuntimeError(f"Given input '{input}' is not valid!") - output = tuple(input) + output = (input[0], input[1]) return output -def padding_converter_2d(input, kernel_size): +def padding_converter_2d( + input: PaddingType | int | Sequence[int] | Sequence[Sequence[int]], + kernel_size: tuple[int, int] | int, +) -> tuple[int, int] | tuple[tuple[int, int], tuple[int, int]]: output: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] if isinstance(input, PaddingType): if input == PaddingType.VALID: @@ -262,18 +270,16 @@ def padding_converter_2d(input, kernel_size): "'same' padding is not supported when the kernel size is even!" ) output = (kernel_size[0] // 2, kernel_size[1] // 2) - elif isinstance(kernel_size, int): + else: if kernel_size % 2 == 0: raise RuntimeError( "'same' padding is not supported when the kernel size is even!" ) half = kernel_size // 2 output = ((half, half), (half, half)) - else: - raise RuntimeError("Kernel size must be 'tuple[int, int]' or 'int'!") elif isinstance(input, int): output = (input, input) - elif isinstance(input, Sequence): + else: if isinstance(input[0], int) and isinstance(input[1], int): output = (input[0], input[1]) elif isinstance(input[0], Sequence) and isinstance(input[1], Sequence): @@ -284,14 +290,22 @@ def padding_converter_2d(input, kernel_size): return output -def stride_converter(input, kernel_size): +def stride_converter( + input: int | PaddingType | tuple[int, int] | None, + kernel_size: int | tuple[int, int], +): if input is None: return kernel_size else: return input -def tuple_converter(input): +def tuple_converter( + input: int + | PaddingType + | tuple[int, int] + | tuple[tuple[int, int], tuple[int, int]], +): if isinstance(input, int): return (input, input) else: diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py index e3433282..b59ccf51 100644 --- a/mithril/backends/with_autograd/jax_backend/backend.py +++ b/mithril/backends/with_autograd/jax_backend/backend.py @@ -45,7 +45,7 @@ class JaxBackend(ParallelBackend[jax.numpy.ndarray]): """ type = "jax" - registered_primitives = {} + registered_primitives: dict[str, Callable[..., jax.numpy.ndarray]] = {} primitive_fn_path = "mithril.backends.with_autograd.jax_backend.ops" def __init__( @@ -107,7 +107,7 @@ def get_available_devices(): return utils.get_available_devices() @staticmethod - def register_primitive(fn: Callable) -> None: + def register_primitive(fn: Callable[..., Any]) -> None: JaxBackend.registered_primitives[fn.__name__] = fn def set_seed(self, seed: int): @@ -131,7 +131,7 @@ def to_device( return jax.device_put(data, device=_device).block_until_ready() return jax.device_put(data, device=_device) - def block_until_ready(self, data: jax.Array): + def block_until_ready(self, data: jax.Array) -> jax.Array | None: """Block until the specified data is ready. Parameters @@ -139,9 +139,11 @@ def block_until_ready(self, data: jax.Array): data: jax.Array The data for which the method will block until it is ready. """ - data.block_until_ready() + return data.block_until_ready() - def _creation_fn_wrapper(self, fn: Callable) -> Callable: + def _creation_fn_wrapper( + self, fn: Callable[..., jax.Array] + ) -> Callable[..., jax.Array]: """ Wrapper for array creation functions. @@ -170,7 +172,9 @@ def _creation_fn_wrapper(self, fn: Callable) -> Callable: return array_conversion_fn - def _conversion_fn_wrapper(self, fn: Callable) -> Callable: + def _conversion_fn_wrapper( + self, fn: Callable[..., jax.Array] + ) -> Callable[..., jax.Array]: """ Wrapper for array conversion functions. @@ -202,7 +206,13 @@ def _conversion_fn_wrapper(self, fn: Callable) -> Callable: return array_conversion_fn - def _parallelize(self, *args, fn: Callable, device_mesh, **kwargs) -> jax.Array: + def _parallelize( + self, + *args: Any, + fn: Callable[..., jax.Array], + device_mesh: tuple[int, ...], + **kwargs: Any, + ) -> jax.Array: """ Parallelizes the function's return tensor across devices. @@ -227,7 +237,7 @@ def _parallelize(self, *args, fn: Callable, device_mesh, **kwargs) -> jax.Array: return self._parallel_manager.parallelize(tensor, device_mesh) def _register_callable( - self, fn: Callable | partial, fn_name: str, jit: bool = False + self, fn: Callable[..., Any], fn_name: str, jit: bool = False ): assert ( self._parallel_manager is not None @@ -236,7 +246,7 @@ def _register_callable( fn_name = str(id(self)) + fn_name return self._parallel_manager.register_callable(fn, fn_name, jit) - def _run_callable(self, *primals, fn_name: str): + def _run_callable(self, *primals: jax.Array, fn_name: str): assert ( self._parallel_manager is not None ), "Parallel manager is not initialized!" @@ -441,32 +451,32 @@ def flatten( ) -> jax.Array: return ops.flatten(input, start_dim=start_dim, end_dim=end_dim) - def abs(self, array: jax.Array) -> jax.Array: - return jax.numpy.abs(array) + def abs(self, input: jax.Array) -> jax.Array: + return jax.numpy.abs(input) def sign(self, input: jax.Array) -> jax.Array: return jax.numpy.sign(input) - def sin(self, array: jax.Array) -> jax.Array: - return jax.numpy.sin(array) + def sin(self, input: jax.Array) -> jax.Array: + return jax.numpy.sin(input) - def cos(self, array: jax.Array) -> jax.Array: - return jax.numpy.cos(array) + def cos(self, input: jax.Array) -> jax.Array: + return jax.numpy.cos(input) - def tanh(self, array: jax.Array) -> jax.Array: - return jax.nn.tanh(array) + def tanh(self, input: jax.Array) -> jax.Array: + return jax.nn.tanh(input) - def relu(self, array: jax.Array) -> jax.Array: - return jax.nn.relu(array) + def relu(self, input: jax.Array) -> jax.Array: + return jax.nn.relu(input) - def leaky_relu(self, array: jax.Array, slope: float | jax.Array) -> jax.Array: - return jax.nn.leaky_relu(array, slope) + def leaky_relu(self, input: jax.Array, slope: float | jax.Array) -> jax.Array: + return jax.nn.leaky_relu(input, slope) - def sigmoid(self, array: jax.Array) -> jax.Array: - return jax.nn.sigmoid(array) + def sigmoid(self, input: jax.Array) -> jax.Array: + return jax.nn.sigmoid(input) - def softplus(self, array: jax.Array) -> jax.Array: - return jax.nn.softplus(array) + def softplus(self, input: jax.Array) -> jax.Array: + return jax.nn.softplus(input) def softmax(self, input: jax.Array, dim: int = -1) -> jax.Array: # TODO: dim can be Sequence[int] as well. Should work @@ -535,7 +545,7 @@ def transpose( return data.transpose(axes) def unique( - self, input, **kwargs + self, input: jax.Array, **kwargs: Any ) -> tuple[jax.Array, jax.Array | None, jax.Array | None]: return jax.numpy.unique(input, **kwargs) @@ -546,7 +556,7 @@ def topk(self, input: jax.Array, k: int) -> jax.Array: return jax.lax.top_k(input, k)[0] def multinomial( - self, probs: jax.Array, num_samples: int, replacement: bool = False, **kwargs + self, probs: jax.Array, num_samples: int, replacement: bool = False ) -> jax.Array: """ Faster JAX implementation of multinomial sampling. @@ -598,7 +608,7 @@ def multinomial( return samples - def jit(self, *args, **kwargs): + def jit(self, *args: Any, **kwargs: Any): return jax.jit(*args, **kwargs) def grad( diff --git a/mithril/backends/with_autograd/jax_backend/ops.py b/mithril/backends/with_autograd/jax_backend/ops.py index 094e71e6..26469262 100644 --- a/mithril/backends/with_autograd/jax_backend/ops.py +++ b/mithril/backends/with_autograd/jax_backend/ops.py @@ -15,6 +15,7 @@ from collections.abc import Callable, Iterator, Sequence from functools import partial from itertools import combinations_with_replacement +from typing import Any import jax import jax.numpy as jnp @@ -309,7 +310,7 @@ def relu(input: jax.Array) -> jax.Array: return functionals.relu(input) -def leaky_relu(input: jax.Array, slope: jax.Array): +def leaky_relu(input: jax.Array, slope: jax.Array) -> jax.Array: return functionals.leaky_relu(input, slope) @@ -617,7 +618,9 @@ def cross_entropy( categorical: bool = True, robust: bool = False, ) -> jax.Array: - log: partial | Callable = partial(robust_log, cutoff=cutoff) if robust else jnp.log + log: partial[jax.Array] | Callable[..., jax.Array] = ( + partial(robust_log, cutoff=cutoff) if robust else jnp.log + ) _weights = calculate_cross_entropy_class_weights( input, target, categorical, weights ) @@ -644,7 +647,9 @@ def cross_entropy_with_logits( categorical: bool = True, robust: bool = False, ) -> jax.Array: - log: partial | Callable = partial(robust_log, cutoff=cutoff) if robust else jnp.log + log: partial[jax.Array] | Callable[..., jax.Array] = ( + partial(robust_log, cutoff=cutoff) if robust else jnp.log + ) _weights = calculate_cross_entropy_class_weights( input, target, categorical, weights ) @@ -692,7 +697,9 @@ def binary_cross_entropy( pos_weight: bool | float = 1.0, robust: bool = False, ) -> jax.Array: - log: partial | Callable = partial(robust_log, cutoff=cutoff) if robust else jnp.log + log: partial[jax.Array] | Callable[..., jax.Array] = ( + partial(robust_log, cutoff=cutoff) if robust else jnp.log + ) _pos_weight: jax.Array | float if isinstance(pos_weight, bool): @@ -711,7 +718,9 @@ def binary_cross_entropy_with_logits( pos_weight: float | bool = 1.0, robust: bool = False, ) -> jax.Array: - log: partial | Callable = partial(robust_log, cutoff=cutoff) if robust else jnp.log + log: partial[jax.Array] | Callable[..., jax.Array] = ( + partial(robust_log, cutoff=cutoff) if robust else jnp.log + ) _pos_weight: jax.Array | float if isinstance(pos_weight, bool): @@ -811,10 +820,8 @@ def size(input: jax.Array, dim: int | tuple[int, ...] | None) -> int | tuple[int return input.size if isinstance(dim, int): return input.shape[dim] - if isinstance(dim, tuple): - return tuple(input.shape[idx] for idx in dim) else: - raise ValueError(f"Unexpected dim: {dim}") + return tuple(input.shape[idx] for idx in dim) def flatten(input: jax.Array, *, start_dim: int = 0, end_dim: int = -1) -> jax.Array: @@ -858,7 +865,7 @@ def tensor_to_list(input: jax.Array) -> NestedFloatOrIntOrBoolList: return input.tolist() -def arange(*args, device: str, precision: int) -> jax.Array: +def arange(*args: Any, device: str, precision: int) -> jax.Array: with jax.default_device(get_device(device)): return handle_data_precision(jnp.arange(*args), precision) @@ -967,16 +974,16 @@ def gpr_v_outer(K: jax.Array, K_term: jax.Array, L: jax.Array) -> jax.Array: return v_outer -def isnan(input): +def isnan(input: jax.Array) -> jax.Array: return jnp.isnan(input) def nan_to_num( - input, + input: jax.Array, nan: int | float | None, posinf: int | float | None, neginf: int | float | None, -): +) -> jax.Array: return jnp.nan_to_num(input, nan=nan, posinf=posinf, neginf=neginf) # type: ignore diff --git a/mithril/backends/with_autograd/jax_backend/parallel.py b/mithril/backends/with_autograd/jax_backend/parallel.py index 1e2efc81..2ae2c9bd 100644 --- a/mithril/backends/with_autograd/jax_backend/parallel.py +++ b/mithril/backends/with_autograd/jax_backend/parallel.py @@ -14,6 +14,7 @@ import math from collections.abc import Callable +from typing import Any import jax from jax.experimental import mesh_utils @@ -31,7 +32,7 @@ def __init__(self, n_devices: int, device: str) -> None: ) super().__init__(n_devices) - def run_callable(self, *primals, fn_name: str): + def run_callable(self, *primals: jax.Array, fn_name: str): return self.callables[fn_name](*primals) def parallelize( @@ -42,7 +43,7 @@ def parallelize( # transform user provided device mesh to the one that satisfies the condition, # and replicate the dimensions that are provided as 1 in the device mesh. - replicate_dims = [] + replicate_dims: list[int] = [] _device_mesh = [1] * tensor.ndim if device_mesh is None else list(device_mesh) @@ -66,7 +67,7 @@ def parallelize( return jax.device_put(tensor, sharding) - def register_callable(self, fn: Callable, fn_name: str, jit: bool): + def register_callable(self, fn: Callable[..., Any], fn_name: str, jit: bool): if jit: fn = jax.jit(fn) diff --git a/mithril/backends/with_autograd/jax_backend/utils.py b/mithril/backends/with_autograd/jax_backend/utils.py index 3725a179..002d1b36 100644 --- a/mithril/backends/with_autograd/jax_backend/utils.py +++ b/mithril/backends/with_autograd/jax_backend/utils.py @@ -24,7 +24,7 @@ ArrayType = jax.Array -dtype_map = { +dtype_map: dict[str | None, Any] = { "int16": jnp.int16, "int32": jnp.int32, "int": jnp.int32, @@ -51,7 +51,7 @@ def broadcast_to_highest( ) -def vmapper(func: Callable, count: int) -> Callable: +def vmapper(func: Callable[..., jax.Array], count: int) -> Callable[..., jax.Array]: for _ in range(count): func = jax.vmap(func) return func @@ -77,7 +77,7 @@ def robust_power_above_threshold( def robust_power_helper( input1: jax.Array, input2: jax.Array, threshold: jax.Array ) -> jax.Array: - def cond_fun(cond, input1, input2): + def cond_fun(cond: jax.Array, input1: jax.Array, input2: jax.Array) -> jax.Array: return jax.lax.cond( cond, robust_power_under_threshold, @@ -99,7 +99,7 @@ def cond_fun(cond, input1, input2): def robust_log_helper(input1: jax.Array, threshold: jax.Array) -> jax.Array: - def cond_fun(cond, input1): + def cond_fun(cond: jax.Array, input1: jax.Array): return jax.lax.cond( cond, lambda x: jnp.log(threshold) + (jnp.abs(x) / threshold) - 1.0, @@ -112,7 +112,7 @@ def cond_fun(cond, input1): def stable_reciprocal_helper(input1: jax.Array, threshold: jax.Array) -> jax.Array: - def cond_fun(cond, input1): + def cond_fun(cond: jax.Array, input1: jax.Array): return jax.lax.cond( cond, lambda x: -x / jnp.square(threshold) @@ -126,7 +126,7 @@ def cond_fun(cond, input1): def robust_sqrt_helper(input1: jax.Array, threshold: jax.Array) -> jax.Array: - def cond_fun(cond, input1): + def cond_fun(cond: jax.Array, input1: jax.Array): return jax.lax.cond( cond, lambda x: jnp.abs(x) * jnp.reciprocal(jnp.sqrt(threshold)), @@ -181,7 +181,7 @@ def tsne_softmax( def calc_prob_matrix( - negative_dist_sq: jax.Array, sigmas: jax.Array, zero_index=None + negative_dist_sq: jax.Array, sigmas: jax.Array, zero_index: int | None = None ) -> jax.Array: """Convert a distances matrix to a matrix of probabilities. Parameters @@ -254,12 +254,12 @@ def find_optimal_sigmas( jax.Array Returns optimal sigma values. """ - sigmas = [] + sigmas: list[float] = [] # For each row of the matrix (each point in our dataset) for i in range(negative_dist_sq.shape[0]): # Make fn that returns perplexity of this row given sigma - def eval_fn(sigma): + def eval_fn(sigma: float) -> jax.Array: return perplexity_fn(negative_dist_sq[i, :], jnp.array(sigma), i, threshold) # noqa: B023 # Binary search over sigmas to achieve target perplexity @@ -270,7 +270,7 @@ def eval_fn(sigma): return jnp.array(sigmas, dtype=negative_dist_sq.dtype) -def polynomial_features_helper(x, y): +def polynomial_features_helper(x: jax.Array, y: jax.Array) -> jax.Array: # NOTE: This helper function is used to handle (0.0 ** 0) case. JAX original # power function gradient returns NAN for this case but we want the gradient # return 0.0 without changing forward characteristics for this point. @@ -285,7 +285,7 @@ def polynomial_features_helper(x, y): def get_available_devices(): - backends = set(jax._src.xla_bridge.backends()) - set(["interpreter"]) # type: ignore + backends: set[str] = set(jax._src.xla_bridge.backends()) - set(["interpreter"]) devices = [ f"{backend.replace('METAL','mps')}:{idx}" for backend in list(backends) @@ -313,7 +313,7 @@ def get_device(device: str): def _get_available_backends() -> list[str]: - backends = set(jax._src.xla_bridge.backends()) - set(["interpreter"]) # type: ignore + backends: set[str] = set(jax._src.xla_bridge.backends()) - set(["interpreter"]) return list(backends) @@ -333,7 +333,7 @@ def _parse_device_string(device: str): return backend, device_idx -def handle_dtype(dtype: Any) -> Any: +def handle_dtype(dtype: str | core.Dtype | jnp.dtype[Any]) -> jnp.dtype[Any]: if isinstance(dtype, core.Dtype): return dtype_map[dtype.name] elif isinstance(dtype, str) and dtype in dtype_map: @@ -346,10 +346,14 @@ def handle_dtype(dtype: Any) -> Any: def creation_fn_wrapper( - *args, fn: Callable, dtype=None, device: str, precision: int, **kwargs + *args: Any, + fn: Callable[..., jax.Array], + dtype: core.Dtype | jnp.dtype[Any] | None = None, + device: str, + precision: int, + **kwargs: Any, ): - if isinstance(device, str): - _device = get_device(device) + _device = get_device(device) if dtype is not None: dtype = handle_dtype(dtype) @@ -363,10 +367,15 @@ def creation_fn_wrapper( def conversion_fn_wrapper( - data, *args, fn: Callable, device: str, precision: int, dtype=None, **kwargs + data: Any, + *args: Any, + fn: Callable[..., jax.Array], + device: str, + precision: int, + dtype: core.Dtype | jnp.dtype[Any] | None = None, + **kwargs: Any, ): - if isinstance(device, str): - _device = get_device(device) + _device = get_device(device) if dtype is not None: dtype = handle_dtype(dtype) @@ -400,15 +409,14 @@ def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: def handle_data_dtype(data: jax.Array, dtype: core.Dtype | int) -> jax.Array: - if isinstance(dtype, int): - dtype = core.Dtype(dtype) + dtype = core.Dtype(dtype) if data.dtype != dtype_map[dtype.name]: return data.astype(dtype_map[dtype.name]) return data -def get_type(input: int | float | bool | Sequence, precision: int): +def get_type(input: int | float | bool | Sequence[Any], precision: int): type = find_dominant_type(input).__name__ if type == "bool": return jax.numpy.bool_ @@ -416,7 +424,7 @@ def get_type(input: int | float | bool | Sequence, precision: int): return getattr(jax.numpy, type + str(precision)) -def calculate_tpr_fpr(threshold, input, label): +def calculate_tpr_fpr(threshold: jax.Array, input: jax.Array, label: jax.Array): input_c = input.copy() n_positive = (label == 1).sum() @@ -431,7 +439,7 @@ def calculate_tpr_fpr(threshold, input, label): return tpr, fpr -def log_sigmoid(input: jax.Array, log: Callable, robust: bool): +def log_sigmoid(input: jax.Array, log: Callable[..., jax.Array], robust: bool): min = jnp.minimum(0, input) input = jnp.exp(-jnp.abs(input)) if not robust: @@ -439,22 +447,24 @@ def log_sigmoid(input: jax.Array, log: Callable, robust: bool): return min - log(1 + input) -def log_softmax(input: jax.Array, log: Callable, robust: bool, axis: int = -1): +def log_softmax( + input: jax.Array, log: Callable[..., jax.Array], robust: bool, axis: int = -1 +) -> jax.Array: if not robust: return jax.nn.log_softmax(input, axis) return input - log(jnp.exp(input).sum(axis=axis, keepdims=True)) -def calculate_binary_class_weight(labels) -> jax.Array: +def calculate_binary_class_weight(labels: jax.Array) -> jax.Array: return (1 - labels.mean()) / labels.mean() -def calculate_categorical_class_weight(labels, num_classes: int): +def calculate_categorical_class_weight(labels: jax.Array, num_classes: int): one_hot = jnp.eye(num_classes)[labels] return calculate_class_weight(one_hot) -def calculate_class_weight(labels): +def calculate_class_weight(labels: jax.Array) -> jax.Array: # Expected shape (N, C, ...) or (C) return ( (1 / labels.sum(axis=tuple(i for i in range(labels.ndim) if i != 1))) diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index cad491fd..ba5e858b 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -31,7 +31,7 @@ class MlxBackend(Backend[mx.array]): type = "mlx" supported_precisions = [16, 32] - registered_primitives = {} + registered_primitives: dict[str, Callable[..., mx.array]] = {} primitive_fn_path = "mithril.backends.with_autograd.mlx_backend.ops" def __init__( @@ -78,7 +78,7 @@ def get_available_devices(): return utils.get_available_devices() @staticmethod - def register_primitive(fn: Callable) -> None: + def register_primitive(fn: Callable[..., mx.array]) -> None: MlxBackend.registered_primitives[fn.__name__] = fn def set_seed(self, seed: int): @@ -93,14 +93,18 @@ def to_device( def block_until_ready(self, data: mx.array): mx.eval(data) - def _creation_fn_wrapper(self, fn: Callable) -> Callable: + def _creation_fn_wrapper( + self, fn: Callable[..., mx.array] + ) -> Callable[..., mx.array]: return partial( utils.creation_fn_wrapper, fn=fn, precision=self.precision, ) - def _conversion_fn_wrapper(self, fn: Callable) -> Callable: + def _conversion_fn_wrapper( + self, fn: Callable[..., mx.array] + ) -> Callable[..., mx.array]: return partial( utils.conversion_fn_wrapper, fn=fn, @@ -289,7 +293,7 @@ def rand_uniform( low=low, high=high, shape=_shape, dtype=utils.dtype_map[_dtype] ) - def arange(self, *args, dtype: Dtype | None = None) -> mx.array: + def arange(self, *args: float | int, dtype: Dtype | None = None) -> mx.array: _dtype: str | None = None if isinstance(dtype, Dtype): _dtype = dtype.name @@ -354,8 +358,8 @@ def log(self, input: mx.array) -> mx.array: def isnan(self, input: mx.array) -> mx.array: return mx.isnan(input) - def stop_gradient(self, data: mx.array) -> mx.array: - return mx.stop_gradient(data) + def stop_gradient(self, input: mx.array) -> mx.array: + return mx.stop_gradient(input) def squeeze(self, input: mx.array) -> mx.array: return mx.squeeze(input) @@ -420,7 +424,7 @@ def topk(self, input: mx.array, k: int) -> mx.array: return -mx.sort(-mx.topk(input, k)) def multinomial( - self, probs: mx.array, num_samples: int, replacement: bool = False, **kwargs + self, probs: mx.array, num_samples: int, replacement: bool = False ) -> mx.array: """ MLX implementation matching torch.multinomial behavior. @@ -530,10 +534,10 @@ def multinomial( def jit(self, fn: Callable[..., Any]) -> Callable[..., Any]: return fn - def grad(self, fn: Callable) -> Callable: + def grad(self, fn: Callable[..., mx.array]) -> Callable[..., mx.array]: return mx.grad(fn) - def value_and_grad(self, fn: Callable) -> Callable: + def value_and_grad(self, fn: Callable[..., mx.array]) -> Callable: return mx.value_and_grad(fn) @overload @@ -646,7 +650,7 @@ def vjp( _cotangents = cotangents if isinstance(cotangents, mx.array): _cotangents = [cotangents] - elif isinstance(cotangents, tuple): + else: _cotangents = list(cotangents) # Calculate VJP. out_list, vjp_list = mx.vjp(_fn, _primals, _cotangents) @@ -683,5 +687,7 @@ def vjp( return output, vjp, aux - def vmap(self, fn: Callable) -> Callable: + def vmap( + self, fn: Callable[[mx.array], mx.array] + ) -> Callable[[mx.array], mx.array]: return mx.vmap(fn) diff --git a/mithril/backends/with_autograd/mlx_backend/ops.py b/mithril/backends/with_autograd/mlx_backend/ops.py index f5e6271a..1fb13159 100644 --- a/mithril/backends/with_autograd/mlx_backend/ops.py +++ b/mithril/backends/with_autograd/mlx_backend/ops.py @@ -372,8 +372,7 @@ def conv1d( padding: tuple[int, int] = (1, 1), dilation: int = 1, ) -> mx.array: - if isinstance(padding, Sequence): - input = mx.pad(input, [(0, 0), (0, 0), (padding[0], padding[1])]) + input = mx.pad(input, [(0, 0), (0, 0), (padding[0], padding[1])]) # Channel first -> Channel last input = mx.swapaxes(input, -2, -1) @@ -567,7 +566,9 @@ def cross_entropy( categorical: bool = True, robust: bool = False, ) -> mx.array: - log: partial | Callable = partial(robust_log, cutoff=cutoff) if robust else mx.log + log: partial[mx.array] | Callable[..., mx.array] = ( + partial(robust_log, cutoff=cutoff) if robust else mx.log + ) _weights = utils.calculate_cross_entropy_class_weights( input, target, categorical, weights ) @@ -592,7 +593,9 @@ def cross_entropy_with_logits( categorical: bool = True, robust: bool = False, ) -> mx.array: - log: partial | Callable = partial(robust_log, cutoff=cutoff) if robust else mx.log + log: partial[mx.array] | Callable[..., mx.array] = ( + partial(robust_log, cutoff=cutoff) if robust else mx.log + ) _weights = utils.calculate_cross_entropy_class_weights( input, target, categorical, weights ) @@ -642,7 +645,10 @@ def binary_cross_entropy( pos_weight: bool | float = 1.0, robust: bool = False, ) -> mx.array: - log: partial | Callable = partial(robust_log, cutoff=cutoff) if robust else mx.log + log: partial[mx.array] | Callable[..., mx.array] = ( + partial(robust_log, cutoff=cutoff) if robust else mx.log + ) + _pos_weight: mx.array | float | bool if isinstance(pos_weight, bool) and pos_weight: _pos_weight = utils.calculate_binary_class_weight(target) else: @@ -659,7 +665,9 @@ def binary_cross_entropy_with_logits( pos_weight: bool | float = 1.0, robust: bool = False, ) -> mx.array: - log: partial | Callable = partial(robust_log, cutoff=cutoff) if robust else mx.log + log: partial[mx.array] | Callable[..., mx.array] = ( + partial(robust_log, cutoff=cutoff) if robust else mx.log + ) _pos_weight: mx.array | float if isinstance(pos_weight, bool): @@ -756,7 +764,7 @@ def tensor_to_list(input: mx.array) -> NestedFloatOrIntOrBoolList: return input.tolist() # type: ignore -def arange(*args, device: str, precision: int) -> mx.array: +def arange(*args: int | float, device: str, precision: int) -> mx.array: out = mx.arange(*args) return utils.handle_data_precision(out, precision) @@ -846,12 +854,12 @@ def polynomial_features(input: mx.array, *, degree: int = 2) -> mx.array: ) -def isnan(input): +def isnan(input: mx.array) -> mx.array: return mx.isnan(input) def nan_to_num( - input, + input: mx.array, nan: int | float | None, posinf: int | float | None, neginf: int | float | None, diff --git a/mithril/backends/with_autograd/mlx_backend/utils.py b/mithril/backends/with_autograd/mlx_backend/utils.py index 2be02d57..04291eca 100644 --- a/mithril/backends/with_autograd/mlx_backend/utils.py +++ b/mithril/backends/with_autograd/mlx_backend/utils.py @@ -14,7 +14,7 @@ from collections.abc import Callable, Sequence from functools import partial -from typing import Any +from typing import Any, TypeGuard import mlx.core as mx import mlx.nn as nn @@ -26,7 +26,7 @@ ArrayType = mx.array -dtype_map = { +dtype_map: dict[str | None, Any] = { "int8": mx.int8, "int16": mx.int16, "short": mx.int16, @@ -51,7 +51,13 @@ def get_device(device: str): return mx.Device(getattr(mx, device), 0) -def creation_fn_wrapper(*args, fn: Callable, dtype=None, precision: int, **kwargs): +def creation_fn_wrapper( + *args: Any, + fn: Callable[..., mx.array], + dtype: core.Dtype | mx.Dtype | None = None, + precision: int, + **kwargs: Any, +): if dtype is not None: dtype = handle_dtype(dtype) data = fn(*args, dtype=dtype, **kwargs) @@ -62,7 +68,12 @@ def creation_fn_wrapper(*args, fn: Callable, dtype=None, precision: int, **kwarg def conversion_fn_wrapper( - data, *args, fn: Callable, precision: int, dtype=None, **kwargs + data: Any, + *args: Any, + fn: Callable[..., mx.array], + precision: int, + dtype: mx.Dtype | None = None, + **kwargs: Any, ): if dtype is not None: dtype = handle_dtype(dtype) @@ -100,15 +111,14 @@ def handle_data_precision(data: mx.array, precision: int) -> mx.array: def handle_data_dtype(data: mx.array, dtype: core.Dtype | int) -> mx.array: - if isinstance(dtype, int): - dtype = core.Dtype(dtype) + dtype = core.Dtype(dtype) if data.dtype != dtype_map[dtype.name]: return data.astype(dtype_map[dtype.name]) return data -def polynomial_features_helper(arr1, arr2): +def polynomial_features_helper(arr1: mx.array, arr2: mx.array): # TODO: Consider using this function also in robust power. broadcasted_shape = np.broadcast_shapes(arr1.shape, arr2.shape) arr1 = mx.broadcast_to(arr1, broadcasted_shape) @@ -117,17 +127,28 @@ def polynomial_features_helper(arr1, arr2): return mx.where(cond, mx.array(1.0, dtype=arr1.dtype), arr1**arr2) -def squeeze_padding(padding): +def squeeze_padding( + padding: Sequence[int] | Sequence[Sequence[int]] | int, +) -> tuple[int, int]: # TODO: When Mlx support properly (4 edge)padding remove this func. - if isinstance(padding, Sequence) and isinstance(padding[0], Sequence): - if padding[0][0] == padding[0][0] and padding[1][0] == padding[1][1]: + + def _is_padding_nested(padding: Any) -> TypeGuard[Sequence[Sequence[int]]]: + return isinstance(padding, Sequence) and isinstance(padding[0], Sequence) + + if _is_padding_nested(padding): + if padding[0][0] == padding[0][1] and padding[1][0] == padding[1][1]: return (padding[0][0], padding[1][0]) else: raise RuntimeError(f"Mlx backend does not support padding: {padding}") - return padding + return padding # type: ignore -def unary_conditional_run(inp, cond, true_fun, false_fun): +def unary_conditional_run( + inp: mx.array, + cond: mx.array, + true_fun: Callable[[mx.array], mx.array], + false_fun: Callable[[mx.array], mx.array], +) -> mx.array: cond = mx.broadcast_to(cond, inp.shape) true_con_flat = mx.flatten(cond) false_cond_flat = mx.logical_not(true_con_flat) @@ -166,7 +187,7 @@ def tsne_softmax( def calc_prob_matrix( - negative_dist_sq: mx.array, sigmas: mx.array, zero_index=None + negative_dist_sq: mx.array, sigmas: mx.array, zero_index: int | None = None ) -> mx.array: """Convert a distances matrix to a matrix of probabilities. Parameters @@ -246,7 +267,7 @@ def find_optimal_sigmas( mx.array Returns optimal sigma values. """ - sigmas = [] + sigmas: list[float] = [] # Make fn that returns perplexity of this row given sigma def eval_fn(sigma, i): @@ -264,7 +285,7 @@ def eval_fn(sigma, i): return mx.array(sigmas, dtype=negative_dist_sq.dtype) -def log_sigmoid(input: mx.array, log: Callable, robust: bool): +def log_sigmoid(input: mx.array, log: Callable[..., mx.array], robust: bool): min = mx.minimum(0, input) input = mx.exp(-mx.abs(input)) if not robust: @@ -272,22 +293,24 @@ def log_sigmoid(input: mx.array, log: Callable, robust: bool): return min - log(1 + input) -def log_softmax(input: mx.array, log: Callable, robust: bool, axis: int = -1): +def log_softmax( + input: mx.array, log: Callable[..., mx.array], robust: bool, axis: int = -1 +) -> mx.array: if not robust: return nn.log_softmax(input, axis) return input - log(mx.exp(input).sum(axis=axis, keepdims=True)) -def calculate_binary_class_weight(labels): +def calculate_binary_class_weight(labels: mx.array) -> mx.array: return (1 - labels.mean()) / labels.mean() -def calculate_categorical_class_weight(labels, num_classes: int): +def calculate_categorical_class_weight(labels: mx.array, num_classes: int): one_hot = mx.eye(num_classes)[labels] return calculate_class_weight(one_hot) -def calculate_class_weight(labels): +def calculate_class_weight(labels: mx.array) -> mx.array: return ( (1 / labels.sum(axis=tuple(i for i in range(labels.ndim) if i != 1))) * labels.sum() @@ -330,10 +353,10 @@ def calculate_cross_entropy_class_weights( def get_submatrices1d( input: mx.array, - output_size, - kernel_width_size, + output_size: Sequence[int], + kernel_width_size: int, padding: int | tuple[int, int] = 0, - stride=1, + stride: int = 1, ): if isinstance(padding, tuple): input = mx.pad(input, ((0, 0), (0, 0), (padding[0], padding[1]))) @@ -356,11 +379,11 @@ def get_submatrices1d( def get_submatrices2d( input: mx.array, - output_size, - kernel_height_size, - kernel_width_size, + output_size: Sequence[int], + kernel_height_size: int, + kernel_width_size: int, padding: int | tuple[tuple[int, int], tuple[int, int]] = 0, - stride=1, + stride: int = 1, ): if isinstance(padding, tuple): input = mx.pad( @@ -396,7 +419,7 @@ def get_submatrices2d( ) -def get_type(input: int | float | bool | Sequence, precision: int): +def get_type(input: int | float | bool | Sequence[Any], precision: int) -> mx.Dtype: type = find_dominant_type(input).__name__ if type == "bool": return mx.bool_ # type: ignore diff --git a/mithril/backends/with_autograd/torch_backend/backend.py b/mithril/backends/with_autograd/torch_backend/backend.py index 2ada2b94..9a2aa763 100644 --- a/mithril/backends/with_autograd/torch_backend/backend.py +++ b/mithril/backends/with_autograd/torch_backend/backend.py @@ -53,7 +53,10 @@ class TorchBackend(ParallelBackend[torch.Tensor]): primitive_fn_path = "mithril.backends.with_autograd.torch_backend.ops" def __init__( - self, device: str = "cpu", precision: int = 32, device_mesh=None + self, + device: str = "cpu", + precision: int = 32, + device_mesh: tuple[int, ...] | None = None, ) -> None: self._device = device self._precision = precision @@ -94,7 +97,7 @@ def get_backend_array_type(self): return torch.Tensor @staticmethod - def register_primitive(fn: Callable) -> None: + def register_primitive(fn: Callable[..., Any]) -> None: TorchBackend.registered_primitives[fn.__name__] = fn @staticmethod @@ -136,7 +139,9 @@ def empty_cache(self) -> None: pass # print(f"Warning: empty_cache is not implemented for {self.device_type}") - def _creation_fn_wrapper(self, fn: Callable) -> Callable: + def _creation_fn_wrapper( + self, fn: Callable[..., torch.Tensor] + ) -> Callable[..., torch.Tensor]: """ Wrapper for PyTorch tensor creation functions. @@ -166,7 +171,9 @@ def _creation_fn_wrapper(self, fn: Callable) -> Callable: return array_creation_fn - def _conversion_fn_wrapper(self, fn: Callable) -> Callable: + def _conversion_fn_wrapper( + self, fn: Callable[..., torch.Tensor] + ) -> Callable[..., torch.Tensor]: """ Wrapper for PyTorch tensor conversion functions. @@ -196,7 +203,11 @@ def _conversion_fn_wrapper(self, fn: Callable) -> Callable: return array_conversion_fn def _parallelize( - self, *args, fn: Callable, device_mesh, **kwargs + self, + *args: Any, + fn: Callable[..., torch.Tensor], + device_mesh: tuple[int] | None, + **kwargs: Any, ) -> DTensor | torch.Tensor: """ Parallelizes the function's return tensor across devices. @@ -224,7 +235,7 @@ def _parallelize( ) def _register_callable( - self, fn: Callable | partial, fn_name: str, jit: bool = False + self, fn: Callable[..., torch.Tensor], fn_name: str, jit: bool = False ): """ Register a callable function with the backend. @@ -272,7 +283,7 @@ def __getstate__(self) -> object: del state["_parallel_manager"] return state - def __setstate__(self, state) -> None: + def __setstate__(self, state: dict[Any, Any]) -> None: self.__dict__.update(state) # Recreate the parallel manager if self._raw_device_mesh is not None: @@ -416,10 +427,10 @@ def rand_uniform( def _arange( self, - *args, + *args: int | float, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, - **kwargs, + **kwargs: int | float, ) -> torch.Tensor: _dtype: str | None = None if isinstance(dtype, Dtype): @@ -568,7 +579,7 @@ def transpose( return ops.transpose(input, axes) def unique( - self, input, **kwargs + self, input: torch.Tensor, **kwargs: Any ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: return torch.unique(input, **kwargs) @@ -581,25 +592,21 @@ def topk(self, input: torch.Tensor, k: int) -> torch.Tensor: return torch.topk(input, k)[0] # TODO: Returns different tuple type??? def multinomial( - self, - probs: torch.Tensor, - num_samples: int, - replacement: bool = False, - **kwargs, + self, probs: torch.Tensor, num_samples: int, replacement: bool = False ) -> torch.Tensor: - return torch.multinomial(probs, num_samples, replacement, **kwargs) + return torch.multinomial(probs, num_samples, replacement) - def jit(self, *args, **kwargs): + def jit(self, *args: Any, **kwargs: Any): backend = "inductor" if "mps" in self._device: backend = "aot_eager" return torch.compile(*args, backend=backend, **kwargs) - def grad(self, fn: Callable) -> Callable: + def grad(self, fn: Callable[..., torch.Tensor]): return torch_grad(fn) def value_and_grad( - self, fn: Callable + self, fn: Callable[..., torch.Tensor] ) -> Callable[..., tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]]: return torch_grad_and_value(fn) diff --git a/mithril/backends/with_autograd/torch_backend/ops.py b/mithril/backends/with_autograd/torch_backend/ops.py index 0486c0b4..70c3709d 100644 --- a/mithril/backends/with_autograd/torch_backend/ops.py +++ b/mithril/backends/with_autograd/torch_backend/ops.py @@ -454,8 +454,7 @@ def conv1d_bias( padding: tuple[int, int] = (1, 1), dilation: tuple[int, int] = (1, 1), ) -> torch.Tensor: - if isinstance(padding, Sequence): - input = F.pad(input, [padding[0], padding[1]], "constant", 0) + input = F.pad(input, [padding[0], padding[1]], "constant", 0) return torch.nn.functional.conv1d( input=input, @@ -531,8 +530,7 @@ def max_pool1d( padding: tuple[int, int] = (0, 0), dilation: int = 1, ) -> torch.Tensor: - if isinstance(padding, Sequence): - input = F.pad(input, [padding[0], padding[1]], "constant", 0) + input = F.pad(input, [padding[0], padding[1]], "constant", 0) return F.max_pool1d( input, @@ -617,7 +615,7 @@ def cross_entropy( categorical: bool = True, robust: bool = False, ) -> torch.Tensor: - log: partial | Callable = ( + log: partial[torch.Tensor] | Callable[..., torch.Tensor] = ( partial(robust_log, cutoff=cutoff) if robust else torch.log ) _weights = calculate_cross_entropy_class_weights( @@ -714,9 +712,10 @@ def binary_cross_entropy( pos_weight: bool | float = 1.0, robust: bool = False, ) -> torch.Tensor: - log: partial | Callable = ( + log: partial[torch.Tensor] | Callable[..., torch.Tensor] = ( partial(robust_log, cutoff=cutoff) if robust else torch.log ) + _pos_weight: torch.Tensor | bool | float # TODO: Use F.binary_cross_entropy if isinstance(pos_weight, bool) and pos_weight: _pos_weight = calculate_binary_class_weight(target) @@ -734,7 +733,7 @@ def binary_cross_entropy_with_logits( pos_weight: bool | float = 1.0, robust: bool = False, ) -> torch.Tensor: - log: partial | Callable = ( + log: partial[torch.Tensor] | Callable[..., torch.Tensor] = ( partial(robust_log, cutoff=cutoff) if robust else torch.log ) _pos_weight: torch.Tensor | None @@ -751,7 +750,6 @@ def binary_cross_entropy_with_logits( return F.binary_cross_entropy_with_logits( input, target, reduction="none", pos_weight=_pos_weight ) - if _pos_weight is not None: log_weight = (_pos_weight - 1) * (target) + 1 loss = (1 - target) * input - (log_weight * log_sigmoid(input, log, robust)) @@ -797,7 +795,7 @@ def kl_divergence( ) -def eye(N: int, M: int, *, device: str, precision: int) -> torch.Tensor: +def eye(N: int, M: int | None, *, device: str, precision: int) -> torch.Tensor: if M is None: return handle_data_precision(torch.eye(N, device=device), precision) else: @@ -808,7 +806,9 @@ def transposed_diag(input: torch.Tensor) -> torch.Tensor: return torch.diag(input)[:, None] -def ones_with_zero_diag(N: int, M: int, device: str, precision: int) -> torch.Tensor: +def ones_with_zero_diag( + N: int, M: int | None, device: str, precision: int +) -> torch.Tensor: if M is None: output = torch.ones(N) - torch.eye(N) else: @@ -881,7 +881,7 @@ def to_parallel(tensor: torch.Tensor, device_mesh: DeviceMesh) -> torch.Tensor: ) -def arange(*args, device: torch.device, precision: int) -> torch.Tensor: +def arange(*args: int | float, device: torch.device, precision: int) -> torch.Tensor: return handle_data_precision(torch.arange(*args, device=device), precision) @@ -915,10 +915,8 @@ def size( return math.prod(input.size()) if isinstance(dim, int): return input.size(dim) - if isinstance(dim, Sequence): - return tuple(input.size(idx) for idx in dim) else: - raise ValueError(f"Unexpected dim: {dim}") + return tuple(input.size(idx) for idx in dim) def norm_modifier(input: torch.Tensor) -> torch.Tensor: @@ -1000,7 +998,9 @@ def polynomial_features(input: torch.Tensor, *, degree: int = 2) -> torch.Tensor input, ) ) - powers: Iterator = map(sum, combinations_with_replacement(identity, degree)) + powers: Iterator[torch.Tensor | int] = map( + sum, combinations_with_replacement(identity, degree) + ) # Skip first element of powers. This is the bias term. next(powers) return torch.hstack([(data**p).prod(1)[:, None] for p in powers]) diff --git a/mithril/backends/with_autograd/torch_backend/parallel.py b/mithril/backends/with_autograd/torch_backend/parallel.py index 8b194b77..acea8631 100644 --- a/mithril/backends/with_autograd/torch_backend/parallel.py +++ b/mithril/backends/with_autograd/torch_backend/parallel.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import atexit import multiprocessing as mp import socket @@ -55,7 +57,7 @@ class TorchParallel(Parallel[torch.Tensor]): used_ports: set[str] = set() device_meshes: dict[tuple[int, ...], DeviceMesh] = {} - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any): if not cls._instance: cls._instance = super().__new__(cls) return cls._instance @@ -98,7 +100,9 @@ def _init_processes(self): for process in processes: process.start() - self.data_queues: list[mp.Queue] = data_queues + self.data_queues: list[mp.Queue[str | Callable[..., torch.Tensor]]] = ( + data_queues + ) self.tensor_id_ref: dict[int, int] = {} while not self.initialized: port_name = self.get_portname() @@ -125,7 +129,7 @@ def _init_device_mesh(self, mesh_shape: tuple[int, ...]) -> DeviceMesh: return TorchParallel.device_meshes[mesh_shape] - def run_callable(self, *primals, fn_name: str): + def run_callable(self, *primals: Any, fn_name: str): primals_ref = apply_to_all_elems( lambda x: TensorRef(self.tensor_id_ref[id(x)]) if isinstance(x, STensor) @@ -149,7 +153,11 @@ def run_callable(self, *primals, fn_name: str): return res def register_callable( - self, fn: Callable, fn_name: str, base_mesh: DeviceMesh, jit: bool = False + self, + fn: Callable[..., torch.Tensor], + fn_name: str, + base_mesh: DeviceMesh, + jit: bool = False, ) -> int: assert self.data_queues is not None, "Parallel manager is not initialized!" @@ -278,7 +286,14 @@ def _send_instrcs( self.instruction_queue.write(instruction.value, op_index, args, kwargs) - def _store_tensors(self, data): + def _store_tensors( + self, + data: dict[str, STensor | DTensor] + | tuple[STensor | DTensor, ...] + | list[STensor | DTensor] + | STensor + | DTensor, + ): match data: case dict(): return {key: self._store_tensors(value) for key, value in data.items()} @@ -298,8 +313,8 @@ def _tensor_callback( self, instruction: Instructions, op_name: str, - args: Any, - kwargs: Any, + args: tuple[Any, ...], + kwargs: dict[str, Any], ): if self.is_alive is False: return @@ -338,7 +353,7 @@ def _initilize_parallel(self, rank: int, device: str, port_name: str): STensor._callback = self._tensor_callback self.initialized = True - def _replicate_cache(self, fn: partial, device_mesh: DeviceMesh): + def _replicate_cache(self, fn: partial[Any], device_mesh: DeviceMesh): # Replicates cache data partially provided to evaluate and evaluate_gradients. if "cache" not in fn.keywords: return fn @@ -362,11 +377,14 @@ def _replicate_cache(self, fn: partial, device_mesh: DeviceMesh): fn.keywords["cache"] = cache_replicated return fn - def _process(self, rank: int, data_queue: mp.Queue): + def _process( + self, rank: int, data_queue: mp.Queue[str | Callable[..., torch.Tensor]] + ): self.tensor_ref: dict[int, DTensor | torch.Tensor] = {} while not self.initialized: port_name = data_queue.get() + assert isinstance(port_name, str) try: self._initilize_parallel(rank, self.device, port_name) except Exception as e: @@ -381,32 +399,31 @@ def _process(self, rank: int, data_queue: mp.Queue): match base_instruction: case Instructions.RUN_OP: op_name = self.op_list[op_index] - args = apply_to_all_elems( + _args = apply_to_all_elems( lambda x: self.tensor_ref[x.id] if isinstance(x, TensorRef) else x, args, ) - kwargs = apply_to_all_elems( + _kwargs = apply_to_all_elems( lambda x: self.tensor_ref[x.id] if isinstance(x, TensorRef) else x, kwargs, ) - result = getattr(torch_ops.aten, op_name)(*args, **kwargs) + result = getattr(torch_ops.aten, op_name)(*_args, **_kwargs) self.tensor_ref[self.tensor_counter] = ( result # Result directly saved ) self.tensor_counter += 1 - case Instructions.FULL_TENSOR: - tensor = apply_to_all_elems( + _tensor = apply_to_all_elems( lambda x: self.tensor_ref[x.id] if isinstance(x, TensorRef) else x, args, )[0] - tensor.full_tensor() + _tensor.full_tensor() case Instructions.REGISTER_CALLABLE: apply_jit = args[0] @@ -421,6 +438,7 @@ def _process(self, rank: int, data_queue: mp.Queue): base_mesh = TorchParallel.device_meshes[base_mesh] fn = data_queue.get() + assert not isinstance(fn, str) dist.barrier(self.communication_group) @@ -433,12 +451,14 @@ def _process(self, rank: int, data_queue: mp.Queue): cache_refs, ) fn.keywords["cache"] = caches - fn = self._replicate_cache(fn, base_mesh) + _fn = self._replicate_cache(fn, base_mesh) + self.callables[fn_name] = _fn if apply_jit == 1: - fn = torch.compile(fn) - - self.callables[fn_name] = fn + _fn = torch.compile(fn) + self.callables[fn_name] = _fn + else: + self.callables[fn_name] = fn case Instructions.RUN_REGISTERED: fn = self.callables[kwargs["fn_name"]] diff --git a/mithril/backends/with_autograd/torch_backend/utils.py b/mithril/backends/with_autograd/torch_backend/utils.py index 8fd51a74..272cd9e9 100644 --- a/mithril/backends/with_autograd/torch_backend/utils.py +++ b/mithril/backends/with_autograd/torch_backend/utils.py @@ -23,7 +23,7 @@ from enum import Enum from functools import partial from multiprocessing import shared_memory -from typing import Any +from typing import Any, overload import numpy as np import torch @@ -148,10 +148,10 @@ def find_optimal_sigmas( np.ndarray Returns optimal sigma values. """ - sigmas = [] + sigmas: list[float] = [] # Make fn that returns perplexity of this row given sigma - def eval_fn(sigma, i): + def eval_fn(sigma: float, i: int): return perplexity_fn(negative_dist_sq[i, :], torch.tensor(sigma), i, threshold) # For each row of the matrix (each point in our dataset) @@ -192,19 +192,19 @@ def handle_dtype(dtype: core.Dtype | torch.dtype | str) -> Any: return dtype_map[dtype.name] elif isinstance(dtype, torch.dtype): return dtype - elif isinstance(dtype, str) and dtype in dtype_map: + elif dtype in dtype_map: return dtype_map[dtype] raise TypeError(f"Provided data type '{dtype}' not understood") def creation_fn_wrapper_inner( - *args, - dtype=None, - fn: Callable, + *args: Any, + dtype: core.Dtype | torch.dtype | str | None = None, + fn: Callable[..., torch.Tensor], device: str, precision: int, device_mesh: tuple[int, ...] | None = None, - **kwargs, + **kwargs: Any, ): _device = get_device(device) if dtype is not None: @@ -218,8 +218,14 @@ def creation_fn_wrapper_inner( def conversion_fn_wrapper_inner( - data, *args, dtype=None, fn: Callable, device: str, precision: int, **kwargs -): + data: Any, + *args: Any, + dtype: torch.dtype | str | None = None, + fn: Callable[..., torch.Tensor], + device: str, + precision: int, + **kwargs: Any, +) -> torch.Tensor: _device = get_device(device) if dtype is not None: dtype = handle_dtype(dtype) @@ -253,6 +259,7 @@ def conversion_fn_wrapper_inner( def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: _dtype = data.dtype + dtype: torch.dtype # Do not make any changes to boolean types. if _dtype != torch.bool: if ( @@ -260,17 +267,18 @@ def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: and not torch.is_complex(data) and _dtype != getattr(torch, f"int{precision}") ): - data = data.type(getattr(torch, f"int{precision}")) + dtype = getattr(torch, f"int{precision}") + data = data.type(dtype) elif torch.is_floating_point(data) and _dtype != getattr( torch, f"float{precision}" ): - data = data.type(getattr(torch, f"float{precision}")) + dtype = getattr(torch, f"float{precision}") + data = data.type(dtype) return data def handle_data_dtype(data: ArrayType, dtype: core.Dtype | int) -> ArrayType: - if isinstance(dtype, int): - dtype = core.Dtype(dtype) + dtype = core.Dtype(dtype) if data.dtype != dtype_map[dtype.name]: as_type = dtype_map[dtype.name] @@ -292,7 +300,9 @@ def get_subtype(data: ArrayType) -> str: return "" -def calculate_tpr_fpr(threshold, input, label): +def calculate_tpr_fpr( + threshold: torch.Tensor, input: torch.Tensor, label: torch.Tensor +): input_c = input.clone() n_positive = (label == 1).sum() @@ -307,7 +317,9 @@ def calculate_tpr_fpr(threshold, input, label): return tpr, fpr -def log_sigmoid(input: torch.Tensor, log: Callable, robust: bool): +def log_sigmoid( + input: torch.Tensor, log: Callable[..., torch.Tensor], robust: bool +) -> torch.Tensor: min = torch.minimum(torch.tensor(0, device=input.device, dtype=input.dtype), input) input = torch.exp(-torch.abs(input)) if not robust: @@ -315,25 +327,29 @@ def log_sigmoid(input: torch.Tensor, log: Callable, robust: bool): return min - log(1 + input) -def log_softmax(input: torch.Tensor, log: Callable, robust: bool, axis: int = -1): +def log_softmax( + input: torch.Tensor, log: Callable[..., torch.Tensor], robust: bool, axis: int = -1 +): if not robust: return torch.log_softmax(input, dim=None) return input - log(torch.exp(input).sum(dim=axis, keepdim=True)) -def calculate_binary_class_weight(labels): +def calculate_binary_class_weight(labels: torch.Tensor) -> torch.Tensor: labels = labels.double() return (1 - labels.mean()) / labels.mean() -def calculate_categorical_class_weight(labels, num_classes: int): +def calculate_categorical_class_weight( + labels: torch.Tensor, num_classes: int +) -> torch.Tensor: one_hot = torch.eye(num_classes)[labels] return calculate_class_weight(one_hot) -def calculate_class_weight(labels): +def calculate_class_weight(labels: torch.Tensor) -> torch.Tensor: return ( - (1 / labels.sum(axis=tuple(i for i in range(labels.ndim) if i != 1))) + (1 / labels.sum(dim=tuple(i for i in range(labels.ndim) if i != 1))) * labels.sum() / labels.shape[1] ) @@ -402,7 +418,34 @@ def init_dist_group(rank: int, world_size: int, device: str = "cpu", port: str = dist.init_process_group(backend=backend_type, rank=rank, world_size=world_size) -def apply_to_all_elems(fn: Callable, data: Any): +# TODO: Reconsider this overload logic after python supports +# intersection and negation in typing https://github.com/python/typing/issues/213 + + +@overload +def apply_to_all_elems[T1, T2]( + fn: Callable[[T1], T2], data: dict[Any, T1] +) -> dict[Any, T2]: ... + + +@overload +def apply_to_all_elems[T1, T2](fn: Callable[[T1], T2], data: list[T1]) -> list[T2]: ... + + +@overload +def apply_to_all_elems[T1, T2]( + fn: Callable[[T1], T2], data: tuple[T1, ...] +) -> tuple[T2, ...]: ... + + +@overload +def apply_to_all_elems[T1, T2](fn: Callable[[T1], T2], data: T1) -> T2: ... + + +def apply_to_all_elems[T1, T2]( + fn: Callable[[T1], T2], + data: T1 | dict[Any, T1] | tuple[T1, ...] | list[T1], +) -> T2 | dict[Any, T2] | tuple[T2, ...] | list[T2]: if isinstance(data, dict): return {key: apply_to_all_elems(fn, value) for key, value in data.items()} elif isinstance(data, tuple): @@ -463,9 +506,15 @@ def __init__(self, nprocesses: int): atexit.register(self._cleanup) - def write(self, opcode1: int, opcode2: int, args: Any = None, kwargs=None) -> None: + def write( + self, + opcode1: int, + opcode2: int, + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + ) -> None: if args is None: - args = [] + args = tuple() if kwargs is None: kwargs = {} @@ -490,7 +539,7 @@ def write(self, opcode1: int, opcode2: int, args: Any = None, kwargs=None) -> No self._index = self._next_index() self._shm.buf[0:1] = self._index.to_bytes() # Write writer index - def read(self, rank: int) -> tuple[int, int, Any, Any]: + def read(self, rank: int) -> tuple[int, int, tuple[Any, ...], dict[str, Any]]: if not (0 <= self._index < self.NUM_ELEMENTS): raise IndexError("Index out of range.") @@ -547,7 +596,9 @@ def _write_memory( + len(kwargs_bytes) ] = kwargs_bytes - def _read_memory(self, index: int) -> tuple[int, int, Any, Any]: + def _read_memory( + self, index: int + ) -> tuple[int, int, tuple[Any, ...], dict[str, Any]]: offset = self._nprocesses + index * self.PAIR_SIZE opcode1, opcode2 = struct.unpack("2i", self._shm.buf[offset : offset + 8]) args_identifier = self._shm.buf[offset + 8 : offset + 12] @@ -589,7 +640,7 @@ def _encode_args_kwargs(self, args: Any, kwargs: Any) -> tuple[bytes, bytes, byt def _decode_args_kwargs( self, offset: int, args_identifier: bytes - ) -> tuple[Any, dict]: + ) -> tuple[tuple[Any, ...], dict[str, Any]]: b_args_identifier = bin(int.from_bytes(args_identifier, "little"))[2:].zfill(32) args_length = int(b_args_identifier[2:17], 2) kwargs_length = int(b_args_identifier[17:], 2) - args_length @@ -615,18 +666,17 @@ def _decode_args(self, offset: int, args_length: int, pickled: bool) -> Any: self._shm.buf[offset : offset + args_length], args_length // 12 ) - def _decode_kwargs(self, offset: int, args_length: int, kwargs_length: int) -> dict: + def _decode_kwargs( + self, offset: int, args_length: int, kwargs_length: int + ) -> dict[str, Any]: return pickle.loads( self._shm.buf[offset + args_length : offset + args_length + kwargs_length] ) - def _args_to_bytes(self, args: Iterable) -> bytes: + def _args_to_bytes(self, args: Iterable[int | float | TensorRef]) -> bytes: args_bytes = b"" - if isinstance(args, Sequence): - for elem in args: - args_bytes += self._value_to_byte(elem) - else: - raise ValueError("Args must be iterable!") + for elem in args: + args_bytes += self._value_to_byte(elem) return args_bytes @@ -704,7 +754,10 @@ def check_device_mesh(base_mesh: DeviceMesh, device_mesh: tuple[int, ...]): ) -def get_type(input: int | float | bool | Sequence, precision: int): +NestedTensorType = int | float | bool | Sequence["NestedTensorType"] + + +def get_type(input: NestedTensorType, precision: int): type = find_dominant_type(input).__name__ if type == "bool": return torch.bool diff --git a/mithril/backends/with_manualgrad/c_backend/backend.py b/mithril/backends/with_manualgrad/c_backend/backend.py index c42758d5..b09dadfa 100644 --- a/mithril/backends/with_manualgrad/c_backend/backend.py +++ b/mithril/backends/with_manualgrad/c_backend/backend.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import numpy as np from .... import core @@ -61,14 +63,16 @@ def zeros( _shape = process_shape(shape) return array.zeros(_shape) - def zeros_like(self, array, dtype: core.Dtype | None = None) -> PyArray: + def zeros_like(self, input: PyArray, dtype: core.Dtype | None = None) -> PyArray: assert dtype is None, "dtype is not supported in CBackend" - return self.array(np.zeros(array.shape, dtype=np.float32)) + return self.array(np.zeros(input.shape, dtype=np.float32)) - def to_numpy(self, array: PyArray) -> np.ndarray: + def to_numpy(self, array: PyArray) -> np.ndarray[Any, Any]: return utils.to_numpy(array) - def array(self, np_array: np.ndarray, dtype: core.Dtype | None = None) -> PyArray: + def array( + self, input: np.ndarray[Any, Any], dtype: core.Dtype | None = None + ) -> PyArray: assert dtype is None, "dtype is not supported in CBackend" - np_array = np_array.astype(np.float32) - return utils.from_numpy(np_array) + input = input.astype(np.float32) + return utils.from_numpy(input) diff --git a/mithril/backends/with_manualgrad/c_backend/src/array.py b/mithril/backends/with_manualgrad/c_backend/src/array.py index 8fdb9b47..8a28d250 100644 --- a/mithril/backends/with_manualgrad/c_backend/src/array.py +++ b/mithril/backends/with_manualgrad/c_backend/src/array.py @@ -14,6 +14,7 @@ import ctypes import os +from typing import Any current_file_path = os.path.abspath(__file__) @@ -50,7 +51,7 @@ class Array(ctypes.Structure): lib.delete_struct.argtypes = [ctypes.POINTER(Array)] -def to_c_int_array(lst): +def to_c_int_array(lst: list[Any]): return (ctypes.c_int * len(lst))(*lst) diff --git a/mithril/backends/with_manualgrad/c_backend/src/array.pyi b/mithril/backends/with_manualgrad/c_backend/src/array.pyi index 175023c9..fd0a17a4 100644 --- a/mithril/backends/with_manualgrad/c_backend/src/array.pyi +++ b/mithril/backends/with_manualgrad/c_backend/src/array.pyi @@ -14,6 +14,7 @@ import builtins import ctypes +from collections.abc import Sequence from types import EllipsisType from typing import Any, overload @@ -32,6 +33,7 @@ class PyArray: ndim: int def data(self) -> NestedList: ... + def __init__(self, arr: Array, shape: tuple[int, ...] | list[int]) -> None: ... def __gt__(self, other: PyArray) -> PyArray: ... def __ge__(self, other: PyArray) -> PyArray: ... def __lt__(self, other: PyArray) -> PyArray: ... @@ -41,31 +43,26 @@ class PyArray: def __xor__(self, other: PyArray) -> PyArray: ... def __invert__(self) -> PyArray: ... def __matmul__(self, other: PyArray) -> PyArray: ... - @overload - def __truediv__(self, other: builtins.int) -> PyArray: ... - @overload - def __truediv__(self, other: PyArray) -> PyArray: ... - @overload - def __div__(self, other: builtins.int) -> PyArray: ... - @overload - def __div__(self, other: PyArray) -> PyArray: ... - @overload - def __floordiv__(self, other: builtins.int) -> PyArray: ... - @overload - def __floordiv__(self, other: PyArray) -> PyArray: ... def __lshift__(self, other: PyArray) -> PyArray: ... def __rshift__(self, other: PyArray) -> PyArray: ... # Construction/Conversion methods def detach(self) -> PyArray: ... def numpy(self) -> Any: ... # Returns numpy.ndarray - def tolist(self) -> list: ... + def tolist(self) -> list[Any]: ... # Basic arithmetic operations def __add__(self, other: PyArray | builtins.float | builtins.int) -> PyArray: ... def __sub__(self, other: PyArray | builtins.float | builtins.int) -> PyArray: ... def __mul__(self, other: PyArray | builtins.float | builtins.int) -> PyArray: ... def __pow__(self, other: PyArray | builtins.float | builtins.int) -> PyArray: ... + def __div__(self, other: PyArray | builtins.float | builtins.int) -> PyArray: ... + def __truediv__( + self, other: PyArray | builtins.float | builtins.int + ) -> PyArray: ... + def __floordiv__( + self, other: PyArray | builtins.float | builtins.int + ) -> PyArray: ... def __neg__(self) -> PyArray: ... # Inplace operations @@ -80,7 +77,10 @@ class PyArray: ) -> PyArray: ... def squeeze(self, dim: builtins.int | None = None) -> PyArray: ... def unsqueeze(self, dim: builtins.int) -> PyArray: ... - def transpose(self, dim0: builtins.int, dim1: builtins.int) -> PyArray: ... + @overload + def transpose(self, axes: None | Sequence[int]) -> PyArray: ... + @overload + def transpose(self, *axes: int) -> PyArray: ... def permute(self, *dims: builtins.int) -> PyArray: ... def flatten( self, start_dim: builtins.int = 0, end_dim: builtins.int = -1 @@ -101,7 +101,7 @@ class PyArray: ) -> PyArray: ... def __setitem__( self, - idx: builtins.int | slice | PyArray | tuple, + idx: builtins.int | slice | PyArray | tuple[Any, ...], val: PyArray | builtins.float | builtins.int, ) -> None: ... def index_select(self, dim: builtins.int, index: PyArray) -> PyArray: ... diff --git a/mithril/backends/with_manualgrad/c_backend/src/utils.py b/mithril/backends/with_manualgrad/c_backend/src/utils.py index 9b2fadb7..17fe0886 100644 --- a/mithril/backends/with_manualgrad/c_backend/src/utils.py +++ b/mithril/backends/with_manualgrad/c_backend/src/utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + import numpy as np from .array import Array, PyArray, lib, to_c_float_array, to_c_int_array @@ -21,11 +23,11 @@ def to_numpy(array: PyArray): return np.ctypeslib.as_array(array.arr.contents.data, shape=(array.shape)) -def from_numpy(array: np.ndarray): +def from_numpy(array: np.ndarray[Any, Any]) -> PyArray: shape = array.shape ndim = len(shape) c_shape = to_c_int_array(shape) c_data = to_c_float_array(array) arr: Array = lib.create_struct(c_data, ndim, c_shape) - return PyArray(arr, shape) # type: ignore + return PyArray(arr, shape) diff --git a/mithril/backends/with_manualgrad/common_primitives.py b/mithril/backends/with_manualgrad/common_primitives.py index 104fbece..fe90fa9b 100644 --- a/mithril/backends/with_manualgrad/common_primitives.py +++ b/mithril/backends/with_manualgrad/common_primitives.py @@ -179,8 +179,10 @@ def buffer(input: DataType, cache: CacheType = None): return input -def permute_tensor(input: DataType, indices: DataType, cache: CacheType = None): - return input[indices] # type: ignore +def permute_tensor( + input: DataType, indices: DataType, cache: CacheType = None +) -> DataType: + return input[indices] def reshape(input: DataType, shape: tuple[int, ...], cache: CacheType = None): diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index fca06f99..62add425 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -21,6 +21,7 @@ from ....core import Dtype from ...backend import Backend, PadWidthType from ...utils import process_shape +from ..common_primitives import CacheType from . import ops, ops_grad, utils @@ -100,7 +101,9 @@ def set_seed(self, seed: int): self.seed = seed np.random.seed(seed) - def _creation_fn_wrapper(self, fn: Callable) -> Callable: + def _creation_fn_wrapper( + self, fn: Callable[..., np.ndarray[Any, Any]] + ) -> Callable[..., np.ndarray[Any, Any]]: """ Wrapper for NumPy array creation functions. @@ -120,7 +123,9 @@ def _creation_fn_wrapper(self, fn: Callable) -> Callable: """ return partial(utils.creation_fn_wrapper, fn=fn, precision=self.precision) - def _conversion_fn_wrapper(self, fn: Callable) -> Callable: + def _conversion_fn_wrapper( + self, fn: Callable[..., np.ndarray[Any, Any]] + ) -> Callable[..., np.ndarray[Any, Any]]: """ Wrapper for NumPy array conversion functions. @@ -142,10 +147,16 @@ def _conversion_fn_wrapper(self, fn: Callable) -> Callable: """ return partial(utils.conversion_fn_wrapper, fn=fn, precision=self.precision) - def accumulate_grads(self, gradient: np.ndarray, input: np.ndarray, cache, idx): + def accumulate_grads( + self, + gradient: np.ndarray[Any, Any], + input: np.ndarray[Any, Any], + cache: CacheType | None, + idx: int, + ) -> np.ndarray[Any, Any]: return utils.accumulate_grads(gradient, input, cache, idx) - def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray: + def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]: _dtype: str | None = None if isinstance(dtype, Dtype): _dtype = dtype.name @@ -198,7 +209,7 @@ def randn( *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None, prng_key: Any = None, - ) -> np.ndarray: + ) -> np.ndarray[Any, Any]: _dtype: str | None = None if isinstance(dtype, Dtype): _dtype = dtype.name @@ -401,7 +412,7 @@ def topk(self, array: np.ndarray, k: int) -> np.ndarray: return values def multinomial( - self, probs: np.ndarray, num_samples: int, replacement: bool = False, **kwargs + self, probs: np.ndarray, num_samples: int, replacement: bool = False ) -> np.ndarray: # input = np.asarray(probs) if probs.ndim == 1: diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops.py b/mithril/backends/with_manualgrad/numpy_backend/ops.py index c64e28dc..3b1735df 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops.py @@ -17,6 +17,7 @@ from collections.abc import Callable, Iterator, Sequence from functools import partial from itertools import combinations_with_replacement +from typing import Any import numpy as np import scipy.linalg as slin # type: ignore[import-untyped] @@ -211,53 +212,78 @@ ] +# TODO: Type annotations of numpy functions are written as np.ndarray[Any, Any] for now, +# However, it can be annotated more precisely in some functions +# (example: np.ndarray[tuple[int, int, *tuple[int, ...]], np.dtype[np.float32]]). +# Above example annotates given arg will have at least two dimensions and +# it has np.float32 dtype. This kind of annotations can be added in the future. + + # Ops -def exp(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def exp( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: output = np.exp(input) return output -def sqrt(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def sqrt( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: output = np.sqrt(input) return output -def sin(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def sin( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: output = np.sin(input) return output -def cos(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def cos( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: output = np.cos(input) return output -def abs(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def abs( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.abs(input) -def sign(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def sign( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.sign(input) -def log(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def log( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.log(input) -def unique(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def unique( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.unique(input) -def trapezoid(y: np.ndarray, x: np.ndarray | None = None) -> np.float64 | np.ndarray: +def trapezoid( + y: np.ndarray[Any, Any], x: np.ndarray[Any, Any] | None = None +) -> np.float64 | np.ndarray[Any, Any]: return np.trapezoid(y, x) def robust_power( - base: np.ndarray, - exponent: np.ndarray, - threshold: np.ndarray, + base: np.ndarray[Any, Any], + exponent: np.ndarray[Any, Any], + threshold: np.ndarray[tuple[()], Any], cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: # Broadcasting threshold for shape calculations threshold = np.resize(threshold, exponent.shape) # cond = (base < threshold) & (exponent < 1.0) @@ -277,8 +303,10 @@ def robust_power( # undefined points (log(0) = -inf in this case), # further testing should be done about performance def robust_sqrt( - input: np.ndarray, cutoff: np.ndarray, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], + cutoff: np.ndarray[tuple[()], Any], + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: input = np.abs(input) inds = input < cutoff output = np.zeros_like(input) @@ -288,8 +316,10 @@ def robust_sqrt( def robust_log( - input: np.ndarray, cutoff: np.ndarray, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], + cutoff: np.ndarray[tuple[()], Any], + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: input = np.abs(input) inds = input < cutoff y_c = np.log(cutoff) @@ -304,8 +334,10 @@ def robust_log( # undefined points (f(0) = inf in this case), # futher testing should be done. def stable_reciprocal( - input: np.ndarray, cutoff: np.ndarray, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], + cutoff: np.ndarray[tuple[()], Any], + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: inds = np.abs(input) < cutoff y_c = np.reciprocal(cutoff) output = np.zeros_like(input) @@ -318,23 +350,31 @@ def stable_reciprocal( # Non linearity funcs -def relu(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def relu( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.maximum(np.array(0.0, dtype=input.dtype), input) def leaky_relu( - input: np.ndarray, slope: float | np.ndarray, cache: CacheType | None = None + input: np.ndarray[Any, Any], + slope: float | np.ndarray[Any, Any], + cache: CacheType | None = None, ): return np.maximum(np.array(0.0, dtype=input.dtype), input) + slope * np.minimum( np.array(0.0, dtype=input.dtype), input ) -def tanh(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def tanh( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.tanh(input) -def sigmoid(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def sigmoid( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: # For numerical stability implement sigmoid with respect to the # sign of input. mask = input >= 0 @@ -344,18 +384,22 @@ def sigmoid(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: return sig -def softplus(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def softplus( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: # See: https://stackoverflow.com/questions/44230635/avoid-overflow-with-softplus-function-in-python return np.log1p(np.exp(-np.abs(input))) + np.maximum(input, 0.0) -def gelu(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def gelu( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return input * (1 + erf(input / np.sqrt(2))) / 2 def softmax( - input: np.ndarray, *, axis: int = -1, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], *, axis: int = -1, cache: CacheType | None = None +) -> np.ndarray[Any, Any]: write_into_cache(cache, "axis", axis) input_tensor = input - np.max(input, axis=axis, keepdims=True) e = np.exp(input_tensor) @@ -365,17 +409,17 @@ def softmax( # Reduction ops def reduce_mean( - input: np.ndarray, + input: np.ndarray[Any, Any], *, axis: int | tuple[int, ...] | None = None, keepdim: bool = False, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.mean(input, axis=axis, keepdims=keepdim) def reduce_sum( - input: np.ndarray, + input: np.ndarray[Any, Any], *, axis: int | tuple[int, ...] | None = None, keepdim: bool = False, @@ -385,76 +429,76 @@ def reduce_sum( def reduce_max( - input: np.ndarray, + input: np.ndarray[Any, Any], *, axis: int | tuple[int, ...] | None = None, keepdim: bool = False, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.max(input, axis=axis, keepdims=keepdim) def reduce_argmax( - input: np.ndarray, + input: np.ndarray[Any, Any], *, axis: int | None = None, keepdim: bool = False, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.argmax(input, axis=axis, keepdims=keepdim) def reduce_min( - input: np.ndarray, + input: np.ndarray[Any, Any], *, axis: int | tuple[int, ...] | None = None, keepdim: bool = False, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.min(input, axis=axis, keepdims=keepdim) def reduce_argmin( - input: np.ndarray, + input: np.ndarray[Any, Any], *, axis: int | None = None, keepdim: bool = False, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.argmin(input, axis=axis, keepdims=keepdim) def reduce_prod( - input: np.ndarray, + input: np.ndarray[Any, Any], *, axis: int | tuple[int, ...] | None = None, keepdim: bool = False, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.prod(input, axis=axis, keepdims=keepdim) def variance( - input: np.ndarray, + input: np.ndarray[Any, Any], *, axis: int | tuple[int, ...] | None = None, keepdim: bool = False, correction: float = 0.0, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.var(input, axis=axis, ddof=correction, keepdims=keepdim) # NN ops def conv1d( - input: np.ndarray, - kernel: np.ndarray, + input: np.ndarray[Any, Any], + kernel: np.ndarray[Any, Any], *, stride: int = 1, padding: tuple[int, int] = (1, 1), dilation: int = 1, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: if dilation != 1: raise NotImplementedError( f"Dilation of {dilation} is not supported. " @@ -468,15 +512,15 @@ def conv1d( def conv1d_bias( - input: np.ndarray, - kernel: np.ndarray, - bias: np.ndarray, + input: np.ndarray[Any, Any], + kernel: np.ndarray[Any, Any], + bias: np.ndarray[Any, Any], *, stride: int = 1, padding: tuple[int, int] = (1, 1), dilation: int = 1, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return ( conv1d( input=input, @@ -491,14 +535,14 @@ def conv1d_bias( def conv2d( - input: np.ndarray, - kernel: np.ndarray, + input: np.ndarray[Any, Any], + kernel: np.ndarray[Any, Any], *, stride: tuple[int, int] = (1, 1), padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (1, 1), dilation: tuple[int, int] = (1, 1), cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: if dilation != (1, 1): raise NotImplementedError( f"Dilation of {dilation} is not supported. " @@ -522,15 +566,15 @@ def conv2d( def conv2d_bias( - input: np.ndarray, - kernel: np.ndarray, - bias: np.ndarray, + input: np.ndarray[Any, Any], + kernel: np.ndarray[Any, Any], + bias: np.ndarray[Any, Any], *, stride: tuple[int, int] = (1, 1), padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (1, 1), dilation: tuple[int, int] = (1, 1), cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return ( conv2d( input=input, @@ -545,14 +589,14 @@ def conv2d_bias( def max_pool1d( - input: np.ndarray, + input: np.ndarray[Any, Any], kernel_size: int, stride: int, *, padding: tuple[int, int] = (0, 0), dilation: int = 1, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: if dilation != 1: raise NotImplementedError( f"Dilation of {dilation} is not supported. " @@ -566,14 +610,14 @@ def max_pool1d( def max_pool2d( - input: np.ndarray, + input: np.ndarray[Any, Any], kernel_size: tuple[int, int], stride: tuple[int, int], *, padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (0, 0), dilation: tuple[int, int] = (1, 1), cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: """Implements torch.nn.functional.max_pool2d in Numpy""" if dilation != (1, 1): @@ -603,10 +647,10 @@ def max_pool2d( def scaled_dot_product_attention( - query: np.ndarray, - key: np.ndarray, - value: np.ndarray, - attn_mask: np.ndarray | None = None, + query: np.ndarray[Any, Any], + key: np.ndarray[Any, Any], + value: np.ndarray[Any, Any], + attn_mask: np.ndarray[Any, Any] | None = None, *, dropout_p: float = 0.0, is_causal: bool = False, @@ -643,20 +687,20 @@ def scaled_dot_product_attention( # Loss funcs def cross_entropy( - input: np.ndarray, - target: np.ndarray, + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], weights: list[float] | bool, - cutoff: np.ndarray, + cutoff: np.ndarray[Any, Any], *, categorical: bool = True, robust: bool = False, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: _weights = calculate_cross_entropy_class_weights( input, target, categorical, weights ) write_into_cache(cache, "weights", _weights) - log: partial | Callable = ( + log: partial[np.ndarray[Any, Any]] | Callable[..., np.ndarray[Any, Any]] = ( partial(robust_log, cutoff=cutoff, cache=None) if robust else np.log ) if categorical: @@ -673,16 +717,16 @@ def cross_entropy( def cross_entropy_with_logits( - input: np.ndarray, - target: np.ndarray, + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], weights: list[float] | bool, - cutoff: np.ndarray, + cutoff: np.ndarray[Any, Any], *, categorical: bool = True, robust: bool = False, cache: CacheType | None = None, -) -> np.ndarray: - log: partial | Callable = ( +) -> np.ndarray[Any, Any]: + log: partial[np.ndarray[Any, Any]] | Callable[..., np.ndarray[Any, Any]] = ( partial(robust_log, cutoff=cutoff, cache=None) if robust else np.log ) _weights = calculate_cross_entropy_class_weights( @@ -705,13 +749,13 @@ def cross_entropy_with_logits( def cross_entropy_with_log_probs( - input: np.ndarray, - target: np.ndarray, + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], weights: list[float] | bool, *, categorical: bool = True, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: _weights = calculate_cross_entropy_class_weights( input, target, categorical, weights ) @@ -729,33 +773,33 @@ def cross_entropy_with_log_probs( def binary_cross_entropy( - input: np.ndarray, - target: np.ndarray, - cutoff: np.ndarray, + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], + cutoff: np.ndarray[Any, Any], *, pos_weight: bool | float = 1.0, robust: bool = False, cache: CacheType | None = None, -) -> np.ndarray: - log: partial | Callable = ( +) -> np.ndarray[Any, Any]: + log: partial[np.ndarray[Any, Any]] | Callable[..., np.ndarray[Any, Any]] = ( partial(robust_log, cutoff=cutoff, cache=None) if robust else np.log ) if isinstance(pos_weight, bool) and pos_weight: - pos_weight = calculate_binary_class_weight(target) + pos_weight = float(calculate_binary_class_weight(target)) write_into_cache(cache, "pos_weight", pos_weight) return -pos_weight * target * log(input) - (1 - target) * log(1 - input) def binary_cross_entropy_with_logits( - input: np.ndarray, - target: np.ndarray, - cutoff: np.ndarray, + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], + cutoff: np.ndarray[Any, Any], *, pos_weight: bool | float = 1.0, robust: bool = False, cache: CacheType | None = None, -) -> np.ndarray: - log: partial | Callable = ( +) -> np.ndarray[Any, Any]: + log: partial[np.ndarray[Any, Any]] | Callable[..., np.ndarray[Any, Any]] = ( partial(robust_log, cutoff=cutoff, cache=None) if robust else np.log ) @@ -777,35 +821,39 @@ def binary_cross_entropy_with_logits( def quantile_loss( - input: np.ndarray, - target: np.ndarray, - quantile: np.ndarray, + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], + quantile: np.ndarray[Any, Any], cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: error = target - input return np.maximum(quantile * error, (quantile - 1) * error) def hinge_loss( - input: np.ndarray, target: np.ndarray, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: base_hinge = 1.0 - target * input write_into_cache(cache, "base_hinge", base_hinge) return np.maximum(0.0, base_hinge) def quad_hinge_loss( - input: np.ndarray, target: np.ndarray, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: return hinge_loss(input, target) ** 2 def kl_divergence( - input: np.ndarray, - target: np.ndarray, - cutoff: np.ndarray, + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], + cutoff: np.ndarray[Any, Any], cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: log_input1 = robust_log(input, cutoff) log_input2 = robust_log(target, cutoff) partial_result = log_input2 - log_input1 @@ -814,23 +862,31 @@ def kl_divergence( def absolute_error( - input: np.ndarray, target: np.ndarray, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], + target: np.ndarray[Any, Any], + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: diff = input - target write_into_cache(cache, "diff", diff) return np.abs(diff) def primitive_accuracy( - input1: np.ndarray, input2: np.ndarray, *, cache: CacheType | None = None -) -> np.ndarray: + input1: np.ndarray[Any, Any], + input2: np.ndarray[Any, Any], + *, + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: prediction = np.argmax(input1, axis=1).reshape(input1.shape[0], 1) return np.mean(prediction == input2) def auc_core( - input: np.ndarray, label: np.ndarray, *, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], + label: np.ndarray[Any, Any], + *, + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: if input.ndim > 1: raise ValueError(f"Input should be 1D array, but given '{input.ndim}D'") if label.ndim > 1: @@ -860,76 +916,85 @@ def auc_core( return np.stack([tprs, fprs]) -def transposed_diag(input: np.ndarray, *, cache: CacheType | None = None) -> np.ndarray: +def transposed_diag( + input: np.ndarray[Any, Any], *, cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.diag(input)[:, np.newaxis] def broadcast_to( - input: np.ndarray, shape: tuple[int, ...], cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], shape: tuple[int, ...], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.broadcast_to(input, shape) def ones_with_zero_diag( *args, precision: int, cache: CacheType | None = None -) -> np.ndarray: +) -> np.ndarray[Any, Any]: n, m = args output = np.ones((n, m)) - np.eye(n, m) if m is not None else np.ones(n) - np.eye(n) return handle_data_precision(output, precision=precision) -def eye(*args, precision: int, cache: CacheType | None = None) -> np.ndarray: +def eye(*args, precision: int, cache: CacheType | None = None) -> np.ndarray[Any, Any]: return handle_data_precision(np.eye(*args), precision=precision) -def squeeze(input: np.ndarray, *, cache: CacheType | None = None) -> np.ndarray: +def squeeze( + input: np.ndarray[Any, Any], *, cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.squeeze(input) def to_tensor( *input: NestedFloatOrIntOrBoolList, precision: int, cache: CacheType | None = None -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.array(input[0], dtype=get_type(input[0], precision=precision)) -def tensor_to_list(input: np.ndarray, cache: CacheType | None = None): +def tensor_to_list(input: np.ndarray[Any, Any], cache: CacheType | None = None): return input.tolist() def primitive_embedding( - input: np.ndarray, embedding_matrix: np.ndarray, *, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[Any, Any], + embedding_matrix: np.ndarray[Any, Any], + *, + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: return embedding_matrix[input] def where( - cond: np.ndarray, - input1: np.ndarray, - input2: np.ndarray, + cond: np.ndarray[Any, Any], + input1: np.ndarray[Any, Any], + input2: np.ndarray[Any, Any], *, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.where(cond, input1, input2) def concat( - *inputs: np.ndarray, axis: int | None = 0, cache: CacheType | None = None -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], axis: int | None = 0, cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return np.concatenate([np.array(v) for v in inputs], axis=axis) -def arange(*args, precision: int, cache: CacheType | None = None) -> np.ndarray: +def arange( + *args, precision: int, cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return handle_data_precision(np.arange(*args), precision) def flatten( - input: np.ndarray, + input: np.ndarray[Any, Any], *, start_dim: int = 0, end_dim: int = -1, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: """Flattens a Numpy array akin to torch.flatten""" if end_dim == -1 or end_dim == len(input.shape): end_dim = len(input.shape) + 1 @@ -942,16 +1007,20 @@ def flatten( return np.reshape(input, shape) -def stop_gradient(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def stop_gradient( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: return input -def shape(input: np.ndarray, cache: CacheType | None = None) -> tuple[int, ...]: +def shape( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> tuple[int, ...]: return input.shape def size( - input: np.ndarray, + input: np.ndarray[Any, Any], dim: int | tuple[int, ...] | None, cache: CacheType | None = None, ) -> int | tuple[int]: @@ -959,24 +1028,24 @@ def size( return input.size if isinstance(dim, int): return input.shape[dim] - if isinstance(dim, Sequence): - return tuple(input.shape[idx] for idx in dim) else: - raise ValueError(f"Unexpected dim: {dim}") + return tuple(input.shape[idx] for idx in dim) -def norm_modifier(input: np.ndarray, cache: CacheType | None = None) -> np.ndarray: +def norm_modifier( + input: np.ndarray[Any, Any], cache: CacheType | None = None +) -> np.ndarray[Any, Any]: inner_term = ((input - 1.0) % 8) / 8 - 0.5 write_into_cache(cache, "inner_term", inner_term) return 4.0 * (1.0 - 2.0 * np.abs(inner_term)) + 1.0 def distance_matrix( - left: np.ndarray, - right: np.ndarray, - norm: np.ndarray, + left: np.ndarray[Any, Any], + right: np.ndarray[Any, Any], + norm: np.ndarray[Any, Any], cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: diffs = left[:, None, :] - right[None, :, :] write_into_cache(cache, "diffs", diffs) abs_diffs = np.abs(diffs) @@ -987,8 +1056,11 @@ def distance_matrix( def polynomial_features( - input: np.ndarray, *, degree: int = 2, cache: CacheType | None = None -) -> np.ndarray: + input: np.ndarray[tuple[int, int], Any], + *, + degree: int = 2, + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: samples, dims = input.shape identity = np.eye(dims + 1, dims + 1, dtype=int) data = np.hstack((np.ones((samples, 1), dtype=input.dtype), input)) @@ -1005,16 +1077,16 @@ def polynomial_features( def tsne_p_joint( - squared_distances: np.ndarray, - target_perplexity: np.ndarray, - threshold: np.ndarray, + squared_distances: np.ndarray[Any, Any], + target_perplexity: np.ndarray[Any, Any], + threshold: np.ndarray[Any, Any], *, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: """Given a data matrix X, gives joint probabilities matrix. Parameters ---------- - squared_distances : np.ndarray + squared_distances : np.ndarray[Any, Any] Square of distance matrix of Input data. target_perplexity : float Desired perplexity value. @@ -1038,8 +1110,8 @@ def tsne_p_joint( def cholesky( - input1: np.ndarray, *, cache: CacheType | None = None -) -> np.ndarray | None: + input1: np.ndarray[Any, Any], *, cache: CacheType | None = None +) -> np.ndarray[Any, Any] | None: try: return np.linalg.cholesky(input1) except np.linalg.LinAlgError as e: @@ -1048,12 +1120,12 @@ def cholesky( def gpr_alpha( - label_mu_diff: np.ndarray, - L: np.ndarray, - K_term: np.ndarray, + label_mu_diff: np.ndarray[Any, Any], + L: np.ndarray[Any, Any], + K_term: np.ndarray[Any, Any], *, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: if L is not None: alpha = slin.solve_triangular( L.T, @@ -1066,12 +1138,12 @@ def gpr_alpha( def eigvalsh( - K_term: np.ndarray, - L: np.ndarray, + K_term: np.ndarray[Any, Any], + L: np.ndarray[Any, Any], threshold: float, *, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: if L is not None: return np.diag(L) else: @@ -1079,8 +1151,12 @@ def eigvalsh( def gpr_v_outer( - K: np.ndarray, K_term: np.ndarray, L: np.ndarray, *, cache: CacheType | None = None -) -> np.ndarray: + K: np.ndarray[Any, Any], + K_term: np.ndarray[Any, Any], + L: np.ndarray[Any, Any], + *, + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: if L is not None: v = slin.solve_triangular(L, K, lower=True) v_outer = v.T @ v @@ -1089,37 +1165,41 @@ def gpr_v_outer( return v_outer -def isnan(input: np.ndarray, *, cache: CacheType | None = None): +def isnan(input: np.ndarray[Any, Any], *, cache: CacheType | None = None): return np.isnan(input) def nan_to_num( - input, + input: np.ndarray[Any, Any], nan: float, posinf: float | None, neginf: float | None, *, cache: CacheType | None = None, ): - return np.nan_to_num(input, nan=nan, posinf=posinf, neginf=neginf) #  type: ignore + return np.nan_to_num(input, nan=nan, posinf=posinf, neginf=neginf) -def astype(input: np.ndarray, dtype: core.Dtype | int) -> np.ndarray: +def astype( + input: np.ndarray[Any, Any], dtype: core.Dtype | int +) -> np.ndarray[Any, Any]: return handle_data_dtype(input, dtype) -def dtype(input: np.ndarray) -> core.Dtype: +def dtype(input: np.ndarray[Any, Any]) -> core.Dtype: return getattr(core, str(input.dtype)) def logical_xor( - left: np.ndarray, right: np.ndarray, cache: CacheType | None = None -) -> np.ndarray: + left: np.ndarray[Any, Any], + right: np.ndarray[Any, Any], + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: return np.logical_xor(left, right) def split( - input: np.ndarray, + input: np.ndarray[Any, Any], split_size: int | list[int], axis: int = 0, cache: CacheType | None = None, @@ -1128,7 +1208,7 @@ def split( def pad( - input: np.ndarray, + input: np.ndarray[Any, Any], pad_width: tuple[tuple[int, int], ...], cache: CacheType | None = None, ): diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py index df966fa3..55a0f7c3 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py @@ -14,6 +14,7 @@ import itertools from itertools import zip_longest +from typing import Any import numpy as np import scipy.linalg as slin # type: ignore[import-untyped] @@ -113,11 +114,11 @@ def matrix_multiplication_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType | None, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) if idx == 0: return output_gradient @ np.swapaxes(inputs[1], -1, -2) @@ -128,22 +129,22 @@ def matrix_multiplication_grad( def multiplication_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) input1, input2 = inputs return [input2, input1][idx] * output_gradient def divide_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: #  type: ignore + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) input1, input2 = inputs if idx == 0: @@ -155,31 +156,31 @@ def divide_grad( def add_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return output_gradient def subtract_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return [1, -1][idx] * output_gradient def squared_error_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) input, target = inputs grad = 2 * (input - target) * output_gradient @@ -187,24 +188,24 @@ def squared_error_grad( def absolute_error_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) signs = np.sign(cache["diff"]) * output_gradient return [1, -1][idx] * signs def cross_entropy_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], categorical: bool = True, robust: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1, 2]) input, target, _, cutoff = inputs if categorical: @@ -238,13 +239,13 @@ def cross_entropy_grad( def cross_entropy_with_logits_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], categorical: bool = True, robust: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input, target, _, cutoff = inputs grad = softmax(input, axis=1) @@ -266,12 +267,12 @@ def cross_entropy_with_logits_grad( def cross_entropy_with_log_probs_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], categorical: bool = True, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: input, target, _ = inputs if categorical: @@ -288,13 +289,13 @@ def cross_entropy_with_log_probs_grad( def binary_cross_entropy_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], pos_weight: bool | float = 1.0, robust: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1, 2]) ( input, @@ -315,13 +316,13 @@ def binary_cross_entropy_grad( def binary_cross_entropy_with_logits_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], pos_weight: bool | float = 1.0, robust: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input, target, *_ = inputs @@ -338,11 +339,11 @@ def binary_cross_entropy_with_logits_grad( def quantile_loss_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[2]) input, target, quantile = inputs dloss = np.sign(target - input) @@ -354,11 +355,11 @@ def quantile_loss_grad( def hinge_loss_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) # TODO: Discuss -> Since we may return original output to the user, \ @@ -370,11 +371,11 @@ def hinge_loss_grad( def quad_hinge_loss_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input, target = inputs return ( @@ -383,11 +384,11 @@ def quad_hinge_loss_grad( def kl_divergence_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[2]) input, target, cutoff = inputs grad = ( @@ -399,11 +400,11 @@ def kl_divergence_grad( def power_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) input, target = inputs if idx == 0: @@ -418,51 +419,51 @@ def power_grad( def exp_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return cache["output"] * output_gradient def sqrt_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return (1 / 2) * (1 / cache["output"]) * output_gradient def sin_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return np.cos(inputs[0]) * output_gradient def cos_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return -np.sin(inputs[0]) * output_gradient def robust_sqrt_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input, cutoff = inputs inds = np.abs(input) < cutoff @@ -476,11 +477,11 @@ def robust_sqrt_grad( # undefined points (log(0) = -inf in this case), # further testing should be done about performance. def robust_log_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType | None, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input, cutoff = inputs negative_inds = input < 0.0 @@ -498,11 +499,11 @@ def robust_log_grad( # undefined points (f(0) = inf in this case), # futher testing should be done. def stable_reciprocal_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input, cutoff = inputs inds = np.abs(input) < cutoff @@ -513,43 +514,43 @@ def stable_reciprocal_grad( def cartesian_diff_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return [1, -1][idx] * np.sum(output_gradient, axis=1 - idx) def abs_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) # sign[sign == 0] = 1 # NOTE: JAX returns 1.0 for gradient at this point!!! return np.sign(inputs[0]) * output_gradient def sign_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return np.zeros_like(inputs[0]) def concat_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, - axis=0, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], + axis: int | None = 0, +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[-1]) # Since last element of args is axis, exclude it from # gradient calculation. @@ -568,13 +569,13 @@ def concat_grad( def reduce_mean_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], axis: int | tuple[int, ...] | None = None, keepdim: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) # TODO: Update gradient formula (same as Sum). (input,) = inputs @@ -595,13 +596,13 @@ def reduce_mean_grad( def reduce_sum_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType | None, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], axis: int | tuple[int, ...] | None = None, keepdim: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input_shape = inputs[0].shape if axis is None: @@ -613,13 +614,13 @@ def reduce_sum_grad( def reduce_max_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType | None, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], axis: int | tuple[int, ...] | None = None, keepdim: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) (input,) = inputs input_shape = input.shape @@ -640,13 +641,13 @@ def reduce_max_grad( def reduce_min_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], axis: int | tuple[int, ...] | None = None, keepdim: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) (input,) = inputs # Expand dimensions of output gradient to match the input shape @@ -666,13 +667,13 @@ def reduce_min_grad( def reduce_prod_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], axis: int | tuple[int, ...] | None = None, keepdim: bool = False, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: (input,) = inputs _axis: tuple[int, ...] @@ -707,21 +708,21 @@ def reduce_prod_grad( def buffer_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return output_gradient def relu_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) (input,) = inputs # TODO: Check if copy is necessary. @@ -731,7 +732,12 @@ def relu_grad( return output_gradient * inp_copy -def leaky_relu_grad(output_gradient, cache, idx, *inputs): +def leaky_relu_grad( + output_gradient: np.ndarray[Any, Any], + cache: CacheType, + idx: int, + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) if idx == 0: # TODO: Check if copy is necessary. @@ -740,48 +746,48 @@ def leaky_relu_grad(output_gradient, cache, idx, *inputs): inp_copy[inp_copy > 0.0] = 1.0 inp_copy[inp_copy <= 0.0] = slope return output_gradient * inp_copy - elif idx == 1: + else: return output_gradient * np.minimum(0.0, inputs[0]) def tanh_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return output_gradient * (1.0 - cache["output"] ** 2) def sigmoid_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) sig = cache["output"] return output_gradient * (sig * (1 - sig)) def softplus_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) - sig = sigmoid(inputs[0]) + sig: np.ndarray[Any, Any] = sigmoid(inputs[0]) return output_gradient * sig def gelu_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: input, *_ = inputs s = input / np.sqrt(2) erf_prime = lambda x: (2 / np.sqrt(np.pi)) * np.exp(-(x**2)) # noqa: E731 @@ -790,34 +796,20 @@ def gelu_grad( def stop_gradient_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: return np.zeros_like(output_gradient) -# def tensor_slice_grad(output_gradient: np.ndarray, -# cache: CacheType, -# idx: int, -# *inputs: np.ndarray, -> np.ndarray: -# verify_shapes(inputs, idx) -# input1, input2 = inputs -# if idx == 0: -# grad = np.zeros_like(input1) -# grad[:input2.shape[0], ...] = output_gradient -# return grad -# elif idx == 1: -# return np.zeros_like(input2) - - def tensor_slice_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1, 2, 3, 4]) input, start, stop, step = inputs grad = np.zeros_like(input) @@ -826,11 +818,11 @@ def tensor_slice_grad( def tensor_item_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input, index = inputs grad = np.zeros_like(input) @@ -839,22 +831,22 @@ def tensor_item_grad( def permute_tensor_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) indices = inputs[1] return output_gradient[np.argsort(indices)] def transpose_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) axes = inputs[1] return ( @@ -865,31 +857,31 @@ def transpose_grad( def square_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return output_gradient * 2 * inputs[0] def conv1d_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], stride: int = 1, padding: tuple[int, int] = (1, 1), dilation: int = 1, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[2, 3, 4]) input1, input2 = inputs n, c, w = input1.shape _, _, w_k = input2.shape out_w = (w - w_k + sum(padding)) // stride + 1 if idx == 0: - _output_gradient = np.zeros( + _output_gradient: np.ndarray[Any, Any] = np.zeros( ( output_gradient.shape[0], output_gradient.shape[1], @@ -926,14 +918,14 @@ def conv1d_grad( def conv1d_bias_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], stride: int = 1, padding: tuple[int, int] = (1, 1), dilation: int = 1, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[3, 4, 5]) if idx < 2: return conv1d_grad( @@ -952,14 +944,14 @@ def conv1d_bias_grad( def conv2d_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], stride: tuple[int, int] = (1, 1), padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (1, 1), dilation: tuple[int, int] = (1, 1), -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[2, 3, 4]) input1, input2 = inputs @@ -1026,14 +1018,14 @@ def conv2d_grad( def conv2d_bias_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], stride: tuple[int, int] = (1, 1), padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (1, 1), dilation: tuple[int, int] = (1, 1), -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[3, 4, 5]) if idx < 2: return conv2d_grad( @@ -1052,27 +1044,27 @@ def conv2d_bias_grad( def flatten_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], start_dim: int = 0, end_dim: int = -1, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1, 2]) return output_gradient.reshape(*inputs[0].shape) def max_pool1d_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], kernel_size: int, stride: int, padding: tuple[int, int] = (0, 0), dilation: int = 1, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: if idx == 0: (input,) = inputs *_, w = input.shape @@ -1099,15 +1091,15 @@ def max_pool1d_grad( def max_pool2d_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], kernel_size: tuple[int, int], stride: tuple[int, int], padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (0, 0), dilation: tuple[int, int] = (1, 1), -) -> np.ndarray: +) -> np.ndarray[Any, Any]: if idx == 0: (input,) = inputs @@ -1157,12 +1149,12 @@ def max_pool2d_grad( def softmax_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], axis: int = -1, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) output = cache["output"] axis = cache["axis"] @@ -1172,11 +1164,11 @@ def softmax_grad( def robust_power_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[2]) input1, input2, threshold = inputs result_shape = np.broadcast_shapes(input1.shape, input2.shape) @@ -1212,11 +1204,11 @@ def robust_power_grad( def norm_modifier_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) inner_term = cache["inner_term"] inner_term_item = inner_term if isinstance(inner_term, float) else inner_term.item() @@ -1231,11 +1223,11 @@ def norm_modifier_grad( def distance_matrix_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) norm = inputs[2] abs_diffs = cache["abs_diffs"] @@ -1258,12 +1250,12 @@ def distance_matrix_grad( def polynomial_features_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], degree: int = 2, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) samples, dims = inputs[0].shape powers = cache["powers"] @@ -1291,61 +1283,34 @@ def polynomial_features_grad( ) -# def ones_with_zero_diag_grad(output_gradient: np.ndarray, -# cache: CacheType, -# idx: int, -# *inputs: np.ndarray, -> np.ndarray: -# verify_shapes(inputs, idx) -# return np.zeros_like(inputs[0]) - -# def eye_grad(output_gradient: np.ndarray, -# cache: CacheType, -# idx: int, -# *inputs: np.ndarray, -> np.ndarray: -# # TODO: Remove gradient formula!!! -# return np.zeros_like(inputs[0]) - - def cholesky_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) # TODO: Implement cholesky gradient! raise NotImplementedError("Implement gradient of Cholesky Factorization!") def gpr_alpha_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) label_mu_diff, L, K_term = inputs raise NotImplementedError() - # if L is not None: - # l_inv = np.linalg.inv(L) - # diff_grad = l_inv.T @ l_inv @ output_gradient - # # (np.kron(l_inv.T, (l_inv.T @ l_inv @ label_mu_diff)) \ - # @ output_gradients.reshape(-1,1)).reshape(2,2) - # # np.tril(-l_inv.T @ ((l_inv @ np.array(label_mu_diff)) \ - # @ np.array([[3.0, 2]])) @ l_inv.T) - # else: - # # IMPLEMENT - # return - # return None - def eigvalsh_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[2]) k_term, l_, threshold = inputs if idx == 0: @@ -1370,11 +1335,11 @@ def eigvalsh_grad( def gpr_v_outer_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) k_, _, l_ = inputs if idx == 0: @@ -1412,11 +1377,11 @@ def gpr_v_outer_grad( def transposed_diag_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) (input,) = inputs eye_mat = np.eye(input.shape[0]) @@ -1424,31 +1389,31 @@ def transposed_diag_grad( def log_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType | None, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return output_gradient * (1 / inputs[0]) def squeeze_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return output_gradient.reshape(inputs[0].shape) def broadcast_to_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: input, shape = inputs[0], inputs[1] input_shape = input.shape bcast_indexes = [] @@ -1468,24 +1433,24 @@ def broadcast_to_grad( def swapaxes_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: *_, axis1, axis2 = inputs return output_gradient.swapaxes(axis1, axis2) def variance_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], axis: int | tuple[int, ...] | None = None, keepdim: bool = False, correction: float = 0.0, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1, 2]) (input,) = inputs shape = input.shape @@ -1507,33 +1472,23 @@ def variance_grad( ) -# def shape_grad(output_gradient: np.ndarray, -# cache: CacheType, -# idx: int, -# *inputs: np.ndarray, -> np.ndarray: -# if idx == 0: -# return np.zeros_like(inputs[0]) -# else: -# return gradient_exception(idx) - - def reshape_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx, non_differentiables=[1]) input, _ = inputs return output_gradient.reshape(input.shape) def where_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: cond, *_ = inputs if idx == 1: @@ -1545,11 +1500,11 @@ def where_grad( def primitive_embedding_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: # TODO: Check this function, if it works properly add tests. verify_shapes(inputs, idx, non_differentiables=[0]) if idx == 1: @@ -1561,10 +1516,10 @@ def primitive_embedding_grad( def scaled_dot_product_attention_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], dropout_p: float = 0.0, is_causal: bool = False, scale: float | int | None = None, @@ -1614,21 +1569,21 @@ def scaled_dot_product_attention_grad( def isnan_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return output_gradient def nan_to_num_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) if idx == 0: return ~(np.isnan(inputs[0]) | np.isinf(inputs[0])) * output_gradient @@ -1637,39 +1592,11 @@ def nan_to_num_grad( ) -# def index_grad(output_gradient: np.ndarray, -# cache: CacheType, -# idx: int, -# *inputs: np.ndarray, -> np.ndarray: -# if idx == 0: -# input, index = inputs -# grad = np.zeros(input.shape) -# grad[index] = output_gradient -# return grad -# else: -# return gradient_exception(idx) - -# def sequence_slice_grad(output_gradient: np.ndarray, -# cache: CacheType, -# idx: int, -# *inputs: np.ndarray, -> np.ndarray: -# if idx == 0: -# input, *_ = inputs -# start_idx = cache["start_idx"] -# stop_idx = cache["stop_idx"] -# step_size = cache["step_size"] -# grad = np.zeros_like(input) -# grad[start_idx: stop_idx: step_size] = output_gradient -# return gradf -# else: -# return gradient_exception(idx) - - def split_grad( - output_gradient: list[np.ndarray], + output_gradient: list[np.ndarray[Any, Any]], cache: CacheType, idx: int, - *inputs: np.ndarray, + *inputs: np.ndarray[Any, Any], ): input, split_size, axis = inputs input_shape = input.shape @@ -1700,11 +1627,11 @@ def split_grad( def pad_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: input, padding = inputs slices = tuple( @@ -1715,11 +1642,11 @@ def pad_grad( def minus_grad( - output_gradient: np.ndarray, + output_gradient: np.ndarray[Any, Any], cache: CacheType, idx: int, - *inputs: np.ndarray, -) -> np.ndarray: + *inputs: np.ndarray[Any, Any], +) -> np.ndarray[Any, Any]: verify_shapes(inputs, idx) return -output_gradient diff --git a/mithril/backends/with_manualgrad/numpy_backend/utils.py b/mithril/backends/with_manualgrad/numpy_backend/utils.py index 758b9a87..3cbab347 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/utils.py +++ b/mithril/backends/with_manualgrad/numpy_backend/utils.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Sequence from functools import partial -from typing import Any, overload +from typing import Any import numpy as np @@ -42,59 +42,15 @@ CacheType = dict[str, Any] -@overload -def write_into_cache( - cache: CacheType | None, - key: str, - value: tuple[np.ndarray, ...], - *, - constant: bool = False, - func: Callable | None = None, -) -> tuple[np.ndarray, ...]: ... - - -@overload -def write_into_cache( - cache: CacheType | None, - key: str, - value: tuple[Any, ...], - *, - constant: bool = False, - func: Callable | None = None, -) -> tuple[Any, ...]: ... - - -@overload -def write_into_cache( - cache: CacheType | None, - key: str, - value: int | float, - *, - constant: bool = False, - func: Callable | None = None, -) -> int | float: ... - - -@overload -def write_into_cache( - cache: CacheType | None, - key: str, - value: np.ndarray, - *, - constant: bool = False, - func: Callable | None = None, -) -> np.ndarray: ... - - # TODO: resolve its types -def write_into_cache( +def write_into_cache[T: np.ndarray[Any, Any] | tuple[Any, ...] | int | float]( cache: CacheType | None, key: str, - value: np.ndarray | tuple[np.ndarray, ...] | int | float, + value: T, *, constant: bool = False, - func: Callable | None = None, -) -> np.ndarray | tuple[np.ndarray, ...] | int | float: + func: Callable[..., Any] | None = None, +) -> T: """Writes key-value pair into the provided cache if there exists. If func given it is called with given value and then result is written into cache for given key. If constant @@ -125,10 +81,7 @@ def write_into_cache( result = value cache[key] = result elif not (constant and key in cache): - if func is None: - result = value - else: - result = func(*value) if isinstance(value, tuple | list) else func(value) + result = func(*value) if isinstance(value, tuple | list) else func(value) cache[key] = result else: result = cache[key] @@ -137,12 +90,12 @@ def write_into_cache( def get_submatrices1d( - input: np.ndarray, - output_size, - kernel_width_size, + input: np.ndarray[Any, Any], + output_size: tuple[int, ...], + kernel_width_size: int, padding: tuple[int, int] = (0, 0), - stride=1, - dilate=0, + stride: int = 1, + dilate: int = 0, ): # TODO: Return type??? working_input = input working_pad = padding @@ -172,13 +125,13 @@ def get_submatrices1d( # TODO: padding, strinde and dilation must be int or tuple. def get_submatrices2d( - input: np.ndarray, - output_size, - kernel_height_size, - kernel_width_size, + input: np.ndarray[Any, Any], + output_size: tuple[int, ...], + kernel_height_size: int, + kernel_width_size: int, padding: tuple[tuple[int, int], tuple[int, int]] = ((0, 0), (0, 0)), - stride=1, - dilate=0, + stride: int = 1, + dilate: int = 0, ): # TODO: Return type??? working_input = input working_pad = padding @@ -220,10 +173,10 @@ def get_submatrices2d( def tsne_softmax( - input_tensor: np.ndarray, + input_tensor: np.ndarray[Any, Any], diag_zero: bool = False, zero_index: int | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: input_tensor = input_tensor - np.max(input_tensor, axis=1, keepdims=True) e = np.exp(input_tensor) if zero_index is None: @@ -236,8 +189,10 @@ def tsne_softmax( def calc_prob_matrix( - negative_dist_sq: np.ndarray, sigmas: np.ndarray, zero_index=None -) -> np.ndarray: + negative_dist_sq: np.ndarray[Any, Any], + sigmas: np.ndarray[Any, Any], + zero_index: int | None = None, +) -> np.ndarray[Any, Any]: """Convert a distances matrix to a matrix of probabilities. Parameters ---------- @@ -255,7 +210,9 @@ def calc_prob_matrix( """ two_sig_sq = 2.0 * np.square(sigmas.reshape((-1, 1))) if two_sig_sq.shape[0] == 1: - dist_sig = [negative_dist_sq / two_sig_sq, 0][np.squeeze(two_sig_sq) == 0.0] + dist_sig = [negative_dist_sq / two_sig_sq, np.array(0.0)][ + np.squeeze(two_sig_sq) == 0.0 + ] else: mask = two_sig_sq == 0.0 dist_sig = np.zeros_like(negative_dist_sq) @@ -264,11 +221,11 @@ def calc_prob_matrix( def perplexity_fn( - negative_dist_sq: np.ndarray, - sigmas: np.ndarray, + negative_dist_sq: np.ndarray[Any, Any], + sigmas: np.ndarray[Any, Any], zero_index: int, threshold: float, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: """Wrapper function for quick calculation of perplexity over a distance matrix. Parameters @@ -293,8 +250,8 @@ def perplexity_fn( def find_optimal_sigmas( - negative_dist_sq: np.ndarray, target_perplexity: int, threshold: float -) -> np.ndarray: + negative_dist_sq: np.ndarray[Any, Any], target_perplexity: int, threshold: float +) -> np.ndarray[Any, Any]: """For each row of distances matrix, find sigma that results in target perplexity for that role. Parameters @@ -308,10 +265,10 @@ def find_optimal_sigmas( np.ndarray Returns optimal sigma values. """ - sigmas = [] + sigmas: list[float] = [] # Make fn that returns perplexity of this row given sigma - def eval_fn(sigma, i): + def eval_fn(sigma: float, i: int): return perplexity_fn(negative_dist_sq[i, :], np.array(sigma), i, threshold) # For each row of the matrix (each point in our dataset) @@ -329,16 +286,18 @@ def eval_fn(sigma, i): # Shared or reused intermediate value calculator functions for primitive models -def find_label_indices(input_array: np.ndarray) -> tuple[np.ndarray, np.ndarray]: +def find_label_indices( + input_array: np.ndarray[Any, Any], +) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]: return (np.arange(len(input_array)), input_array.T) def calc_input_slices( - output_gradient: np.ndarray, axis: int, *args: np.ndarray + output_gradient: np.ndarray[Any, Any], axis: int | None, *args: np.ndarray[Any, Any] ) -> dict[str, tuple[slice, ...]]: # Calculates the slices of output_gradient corresponding to # inputs. - slices = {} + slices: dict[str, tuple[slice, ...]] = {} base_slices = [slice(None)] * output_gradient.ndim finish = 0 for idx, arg in enumerate(args): @@ -362,7 +321,13 @@ def handle_dtype(dtype: Any) -> Any: raise TypeError(f"Provided data type '{dtype}' not understood") from err -def creation_fn_wrapper(*args, fn: Callable, precision: int, dtype=None, **kwargs): +def creation_fn_wrapper( + *args: Any, + fn: Callable[..., np.ndarray[Any, Any]], + precision: int, + dtype: core.Dtype | np.dtype[Any] | None = None, + **kwargs: Any, +): if dtype is not None: dtype = handle_dtype(dtype) data = fn(*args, dtype=dtype, **kwargs) @@ -373,7 +338,12 @@ def creation_fn_wrapper(*args, fn: Callable, precision: int, dtype=None, **kwarg def conversion_fn_wrapper( - data, *args, fn: Callable, precision: int, dtype=None, **kwargs + data: Any, + *args: Any, + fn: Callable[..., np.ndarray[Any, Any]], + precision: int, + dtype: np.dtype[Any] | None = None, + **kwargs: Any, ): if dtype is not None: dtype = handle_dtype(dtype) @@ -388,7 +358,9 @@ def conversion_fn_wrapper( return _data -def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: +def handle_data_precision( + data: np.ndarray[Any, Any], precision: int +) -> np.ndarray[Any, Any]: if isinstance(data, float | int): return data _dtype = data.dtype @@ -405,20 +377,26 @@ def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: return data -def handle_data_dtype(data: np.ndarray, dtype: core.Dtype | int) -> np.ndarray: - if isinstance(dtype, int): - dtype = core.Dtype(dtype) +def handle_data_dtype( + data: np.ndarray[Any, Any], dtype: core.Dtype | int +) -> np.ndarray[Any, Any]: + dtype = core.Dtype(dtype) if data.dtype != dtype_map[dtype.name]: return data.astype(dtype_map[dtype.name]) return data -def make_array(input: int | float | ArrayType, precision): +def make_array(input: int | float | np.ndarray[Any, Any], precision: int): return handle_data_precision(np.array(input), precision=precision) -def accumulate_grads(gradient: np.ndarray, input: np.ndarray, cache, idx): +def accumulate_grads( + gradient: np.ndarray[Any, Any], + input: np.ndarray[Any, Any], + cache: CacheType | None, + idx: int, +) -> np.ndarray[Any, Any]: axes = write_into_cache( cache, "accumulate" + str(idx), @@ -433,7 +411,9 @@ def accumulate_grads(gradient: np.ndarray, input: np.ndarray, cache, idx): return gradient -def _accumulate_grads_helper(grad_shape, input_shape): +def _accumulate_grads_helper( + grad_shape: tuple[int, ...], input_shape: tuple[int, ...] +) -> tuple[int, ...]: # TODO: Refactor the code below rev_grad = list(reversed(grad_shape)) axes = tuple([i for i in range(len(grad_shape) - len(input_shape))]) @@ -447,7 +427,9 @@ def _accumulate_grads_helper(grad_shape, input_shape): return axes -def log_sigmoid(input: np.ndarray, log: Callable, robust: bool): +def log_sigmoid( + input: np.ndarray[Any, Any], log: Callable[..., np.ndarray[Any, Any]], robust: bool +): min = np.minimum(0, input) input = np.exp(-np.abs(input)) if not robust: @@ -455,20 +437,27 @@ def log_sigmoid(input: np.ndarray, log: Callable, robust: bool): return min - log(1 + input) -def log_softmax(input: np.ndarray, log: Callable, robust: bool, axis: int = -1): +def log_softmax( + input: np.ndarray[Any, Any], + log: Callable[..., np.ndarray[Any, Any]], + robust: bool, + axis: int = -1, +): return input - log(np.exp(input).sum(axis=axis, keepdims=True)) -def calculate_binary_class_weight(labels): +def calculate_binary_class_weight(labels: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return (1 - labels.mean()) / labels.mean() -def calculate_categorical_class_weight(labels, num_classes: int): +def calculate_categorical_class_weight( + labels: np.ndarray[Any, Any], num_classes: int +) -> np.ndarray[Any, Any]: one_hot = np.eye(num_classes)[labels] return calculate_class_weight(one_hot) -def calculate_class_weight(labels): +def calculate_class_weight(labels: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return ( (1 / labels.sum(axis=tuple(i for i in range(labels.ndim) if i != 1))) * labels.sum() @@ -477,12 +466,12 @@ def calculate_class_weight(labels): def calculate_cross_entropy_class_weights( - input: np.ndarray, - labels: np.ndarray, + input: np.ndarray[Any, Any], + labels: np.ndarray[Any, Any], is_categorical: bool, weights: bool | list[float], -): - _weights = None +) -> np.ndarray[Any, Any]: + _weights: np.ndarray[Any, Any] if isinstance(weights, bool): if is_categorical: _weights = ( @@ -509,7 +498,10 @@ def calculate_cross_entropy_class_weights( return _weights -def get_type(input: int | float | bool | Sequence, precision: int): +def get_type( + input: int | float | bool | Sequence[int | float | bool | Sequence[Any]], + precision: int, +): type = find_dominant_type(input).__name__ if type == "bool": return np.bool_ @@ -518,7 +510,9 @@ def get_type(input: int | float | bool | Sequence, precision: int): def verify_shapes( - inputs: tuple[np.ndarray, ...], idx: int, non_differentiables=None + inputs: tuple[np.ndarray[Any, Any], ...], + idx: int, + non_differentiables: Iterable[int] | None = None, ) -> None: if idx >= len(inputs): raise Exception(f"Gradient is not defined for the input at index {idx}!") diff --git a/mithril/core.py b/mithril/core.py index ff057a38..a81b3d05 100644 --- a/mithril/core.py +++ b/mithril/core.py @@ -128,7 +128,9 @@ class Dtype(enum.IntEnum): # noqa N801 pass -DataType = TypeVar("DataType", "ndarray", "Array", "Tensor", "PyArray", "array") +DataType = TypeVar( + "DataType", "ndarray[Any, Any]", "Array", "Tensor", "PyArray", "array" +) class GenericDataType(Generic[DataType]): diff --git a/mithril/utils/utils.py b/mithril/utils/utils.py index 51281c79..8a51a644 100755 --- a/mithril/utils/utils.py +++ b/mithril/utils/utils.py @@ -466,7 +466,7 @@ def find_boundary_point( def binary_search( - eval_fn: Callable, + eval_fn: Callable[[Any], Any], target: DataType, *, max_it: int = 1000, @@ -499,7 +499,7 @@ def binary_search( """ it = 0 - def midpoint(a, b): + def midpoint(a: Any, b: Any) -> Any: return (a + b) / 2 while ( diff --git a/tests/scripts/test_parallel.py b/tests/scripts/test_parallel.py index 9bb6d3c6..599bbd9b 100644 --- a/tests/scripts/test_parallel.py +++ b/tests/scripts/test_parallel.py @@ -128,7 +128,7 @@ def test_torch_shared_cyclic_queue_4(): writer.write( opcode1=1, opcode2=4, - args=[5, 5, (4, 2)], + args=(5, 5, (4, 2)), kwargs={"input": "something", "left": (1, 2, (3, 4))}, ) @@ -141,7 +141,7 @@ def test_torch_shared_cyclic_queue_4(): assert reader.read(rank=1) == ( 1, 4, - [5, 5, (4, 2)], + (5, 5, (4, 2)), {"input": "something", "left": (1, 2, (3, 4))}, ) assert instruction[0] == "1" @@ -792,7 +792,7 @@ def test_torch_parallel_error_12(): model = Model() model += Linear(256)(input="input", w="w", b="b") with pytest.raises(AssertionError) as e: - mithril.TorchBackend(device_mesh=2) + mithril.TorchBackend(device_mesh=2) # type: ignore assert str(e.value) == "device_mesh must be a tuple or None." diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 535de469..99a52e97 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import pickle import platform import re From 859a02110cade7d8c851912d22900ba80d604562 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Tue, 17 Dec 2024 16:53:22 +0300 Subject: [PATCH 21/26] backend is typed --- benchmarks/speed_benchmarks/speed_helper.py | 2 +- examples/gpt/run_sample.py | 3 +- mithril/backends/backend.py | 32 +-- .../with_autograd/jax_backend/backend.py | 78 +++---- .../with_autograd/jax_backend/utils.py | 3 +- .../with_autograd/mlx_backend/backend.py | 90 +++----- .../with_autograd/mlx_backend/utils.py | 5 +- .../with_autograd/torch_backend/backend.py | 68 +++--- .../with_autograd/torch_backend/stensor.py | 9 +- .../with_autograd/torch_backend/utils.py | 3 +- .../with_manualgrad/c_backend/backend.py | 4 +- .../with_manualgrad/c_backend/src/array.py | 4 +- .../with_manualgrad/common_primitives.py | 51 ++-- .../with_manualgrad/numpy_backend/backend.py | 217 +++++++++--------- .../with_manualgrad/numpy_backend/ops.py | 8 +- .../with_manualgrad/numpy_backend/utils.py | 3 +- mithril/framework/codegen/python_gen.py | 3 +- mithril/framework/physical/model.py | 4 +- mithril/utils/utils.py | 2 +- tests/scripts/helper.py | 4 +- tests/scripts/test_all_models.py | 18 +- tests/scripts/test_backend_fns.py | 4 +- tests/scripts/test_constant_inputs.py | 2 +- .../test_randomized_models_all_backends.py | 29 +-- tests/scripts/test_scripts.py | 6 +- 25 files changed, 333 insertions(+), 319 deletions(-) diff --git a/benchmarks/speed_benchmarks/speed_helper.py b/benchmarks/speed_benchmarks/speed_helper.py index 2d762d39..20e769a0 100644 --- a/benchmarks/speed_benchmarks/speed_helper.py +++ b/benchmarks/speed_benchmarks/speed_helper.py @@ -129,7 +129,7 @@ def measure_time_and_grads_mithril( grads = model.evaluate_gradients(trainable_params, data) - if model.backend.type == "mlx": + if model.backend.backend_type == "mlx": mx.eval(grads) trainable_params = { key: value - lr * grads[key] for key, value in trainable_params.items() diff --git a/examples/gpt/run_sample.py b/examples/gpt/run_sample.py index 231611f0..7a6fe6f0 100644 --- a/examples/gpt/run_sample.py +++ b/examples/gpt/run_sample.py @@ -16,6 +16,7 @@ import sys import warnings from collections.abc import Callable +from typing import Any import tiktoken from model import create_gpt @@ -136,7 +137,7 @@ def get_weights(logical_model: Model, compiled_model: PhysicalModel, backend: Ba def generate( - model: PhysicalModel, + model: PhysicalModel[Any], block_size: int, weights: dict[str, ml.DataType], idx: ml.DataType, diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py index a6f672bb..d2562b72 100644 --- a/mithril/backends/backend.py +++ b/mithril/backends/backend.py @@ -32,7 +32,7 @@ class Backend(ABC, Generic[DataType]): """Base class for backend implementations in the Mithril library.""" - type = "" + backend_type = "" device_type = None supported_precisions = [16, 32, 64] is_installed = True @@ -80,7 +80,7 @@ def e(self): def is_manualgrad(self) -> bool: raise NotImplementedError("is_manualgrad is not implemented") - def get_backend_array_type(self): # noqa: B902 + def get_backend_array_type(self) -> type[DataType]: raise NotImplementedError("get_backend_array_type is not implemented") @staticmethod @@ -88,7 +88,7 @@ def register_primitive(fn: Callable[..., Any]) -> None: raise NotImplementedError("register_primitive is not implemented!") @abstractmethod - def set_seed(self, seed: int): + def set_seed(self, seed: int) -> None: raise NotImplementedError( "set_seed function must be overriden for every backend individually!" ) @@ -105,6 +105,9 @@ def empty_cache(self): # noqa: B027 pass # print("Warning: empty_cache is not supported!") + # TODO: Fix types in cast function when python + # adds Higher-Kinded TypeVar support. + # https://github.com/python/typing/issues/548#issuecomment-1193345123 def cast(self, value: Any) -> Any: # Simply casts given value to the backend's precision. # If type of value is not int or float, returns the @@ -143,7 +146,7 @@ def arange( dtype: core.Dtype | None = None, ) -> DataType: ... - def arange(self, *args: int | float, **kwargs) -> DataType: + def arange(self, *args: int | float, **kwargs: Any) -> DataType: raise NotImplementedError("arange is not implemented!") def flatten( @@ -257,7 +260,7 @@ def isnan(self, input: DataType) -> DataType: """ raise NotImplementedError("isnan is not implemented!") - def array(self, data: Any, *, dtype: core.Dtype | None = None) -> DataType: + def array(self, input: Any, *, dtype: core.Dtype | None = None) -> DataType: """Returns a backend array on speficied device by copying `data`. Parameters @@ -836,7 +839,7 @@ def multinomial( ) -> DataType: raise NotImplementedError("multinomial is not implemented!") - def jit(self, fn: Callable) -> Callable: + def jit[T: Any](self, fn: Callable[..., T]) -> Callable[..., T]: """ Just-in-time compile the given function. @@ -980,7 +983,7 @@ def vjp( """ raise NotImplementedError("vjp is not implemented!") - def vmap(self, fn: Callable) -> Callable: + def vmap[T: Callable[..., Any]](self, fn: T) -> T: """ Vectorize the given function. @@ -1050,7 +1053,7 @@ def __init__(self, device_mesh: tuple[int, ...] | None) -> None: self._raw_device_mesh = device_mesh self.n_devices = math.prod(device_mesh) if device_mesh is not None else 1 - self._parallel_manager: Parallel | None + self._parallel_manager: Parallel[DataType] | None def zeros( self, @@ -1368,22 +1371,23 @@ def linspace( raise NotImplementedError("linspace is not implemented!") - def _register_callable( - self, fn: Callable | partial, fn_name: str, jit: bool + def _register_callable[T: Any]( + self, fn: Callable[..., T] | partial[T], fn_name: str, jit: bool ) -> None: raise NotImplementedError() - def _run_callable(self, *primals, fn_name: str): + def _run_callable(self, *primals: Any, fn_name: str) -> Any: raise NotImplementedError() - def _create_parallel(self, device_mesh: tuple[int, ...]): + def _create_parallel(self, device_mesh: tuple[int, ...]) -> None: raise NotImplementedError( - f"{self.type.capitalize()} backend does not support parallelization!" + f"{self.backend_type.capitalize()} " + + "backend does not support parallelization!" ) class UnavailableBackend: is_installed = False - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: Any, **kwargs: Any) -> None: raise RuntimeError("Backend is unavailable due to missing dependencies.") diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py index b59ccf51..56ff6e39 100644 --- a/mithril/backends/with_autograd/jax_backend/backend.py +++ b/mithril/backends/with_autograd/jax_backend/backend.py @@ -44,7 +44,7 @@ class JaxBackend(ParallelBackend[jax.numpy.ndarray]): This argument controls whether JAX pre-allocates memory, default is False. """ - type = "jax" + backend_type = "jax" registered_primitives: dict[str, Callable[..., jax.numpy.ndarray]] = {} primitive_fn_path = "mithril.backends.with_autograd.jax_backend.ops" @@ -264,11 +264,11 @@ def array( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] result = self._conversion_fn_wrapper(jax.numpy.array)( - input, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + input, dtype=_dtype, device_mesh=device_mesh ) return result @@ -278,12 +278,12 @@ def zeros( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) result = self._creation_fn_wrapper(jax.numpy.zeros)( - _shape, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + _shape, dtype=_dtype, device_mesh=device_mesh ) return result @@ -293,12 +293,12 @@ def ones( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) result = self._creation_fn_wrapper(jax.numpy.ones)( - _shape, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + _shape, dtype=_dtype, device_mesh=device_mesh ) return result @@ -309,11 +309,11 @@ def ones_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] result = self._creation_fn_wrapper(jax.numpy.ones_like)( - input, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + input, dtype=_dtype, device_mesh=device_mesh ) return result @@ -324,11 +324,11 @@ def zeros_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] result = self._creation_fn_wrapper(jax.numpy.zeros_like)( - input, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + input, dtype=_dtype, device_mesh=device_mesh ) return result @@ -341,12 +341,12 @@ def randn( ) -> jax.Array: if prng_key is None: prng_key = self.prng_key - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) result = self._creation_fn_wrapper(jax.random.normal)( - prng_key, _shape, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + prng_key, _shape, dtype=_dtype, device_mesh=device_mesh ) return result @@ -359,12 +359,12 @@ def rand( ) -> jax.Array: if prng_key is None: prng_key = self.prng_key - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) result = self._creation_fn_wrapper(jax.random.uniform)( - prng_key, _shape, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + prng_key, _shape, dtype=_dtype, device_mesh=device_mesh ) return result @@ -379,16 +379,16 @@ def randint( ) -> jax.Array: if prng_key is None: prng_key = self.prng_key - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) result = self._creation_fn_wrapper(jax.random.randint)( prng_key, _shape, low, high, - dtype=utils.dtype_map[_dtype], + dtype=_dtype, device_mesh=device_mesh, ) return result @@ -404,14 +404,14 @@ def rand_uniform( ) -> jax.Array: if prng_key is None: prng_key = self.prng_key - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(jax.random.uniform)( prng_key, _shape, - dtype=utils.dtype_map[_dtype], + dtype=_dtype, minval=low, maxval=high, device_mesh=device_mesh, @@ -419,16 +419,16 @@ def rand_uniform( def _arange( self, - *args, + *args: int | float, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, - **kwargs, + **kwargs: Any, ) -> jax.Array: - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] return self._creation_fn_wrapper(jax.numpy.arange)( - *args, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + *args, dtype=_dtype, device_mesh=device_mesh ) def linspace( @@ -439,11 +439,11 @@ def linspace( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: str | None = None + _dtype: jax.numpy.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] return self._creation_fn_wrapper(jax.numpy.linspace)( - start, stop, steps, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + start, stop, steps, dtype=_dtype, device_mesh=device_mesh ) def flatten( @@ -540,9 +540,9 @@ def atleast_2d( return jax.numpy.atleast_2d(inputs) def transpose( - self, data: jax.Array, axes: tuple[int, ...] | list[int] | None = None + self, input: jax.Array, axes: tuple[int, ...] | list[int] | None = None ) -> jax.Array: - return data.transpose(axes) + return input.transpose(axes) def unique( self, input: jax.Array, **kwargs: Any @@ -715,7 +715,7 @@ def vjp( (vjp,) = vjp return output, vjp, aux - def vmap( + def vmap( # type: ignore # mypy bug self, fn: Callable[..., dict[str, jax.Array]] ) -> Callable[..., dict[str, jax.Array]]: return jax.vmap(fn) diff --git a/mithril/backends/with_autograd/jax_backend/utils.py b/mithril/backends/with_autograd/jax_backend/utils.py index 002d1b36..2e332b39 100644 --- a/mithril/backends/with_autograd/jax_backend/utils.py +++ b/mithril/backends/with_autograd/jax_backend/utils.py @@ -24,7 +24,7 @@ ArrayType = jax.Array -dtype_map: dict[str | None, Any] = { +dtype_map: dict[str, jnp.dtype[Any]] = { "int16": jnp.int16, "int32": jnp.int32, "int": jnp.int32, @@ -36,7 +36,6 @@ "float64": jnp.float64, "double": jnp.float64, "bool": jnp.bool_, - None: None, } diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index ba5e858b..ec2a1a80 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -29,7 +29,7 @@ class MlxBackend(Backend[mx.array]): - type = "mlx" + backend_type = "mlx" supported_precisions = [16, 32] registered_primitives: dict[str, Callable[..., mx.array]] = {} primitive_fn_path = "mithril.backends.with_autograd.mlx_backend.ops" @@ -187,51 +187,41 @@ def _handle_sequence_type_fun( ] return [output] - def array(self, data: Any, *, dtype: Dtype | None = None) -> mx.array: - _dtype: str | None = None + def array(self, input: Any, *, dtype: Dtype | None = None) -> mx.array: + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._conversion_fn_wrapper(mx.array)( - data, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._conversion_fn_wrapper(mx.array)(input, dtype=_dtype) def zeros( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None ) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.zeros)( - shape=_shape, dtype=utils.dtype_map[_dtype] - ) + return self._creation_fn_wrapper(mx.zeros)(shape=_shape, dtype=_dtype) def ones( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None ) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.ones)( - shape=_shape, dtype=utils.dtype_map[_dtype] - ) + return self._creation_fn_wrapper(mx.ones)(shape=_shape, dtype=_dtype) def ones_like(self, input: mx.array, *, dtype: Dtype | None = None) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._creation_fn_wrapper(mx.ones_like)( - input, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._creation_fn_wrapper(mx.ones_like)(input, dtype=_dtype) def zeros_like(self, input: mx.array, *, dtype: Dtype | None = None) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._creation_fn_wrapper(mx.zeros_like)( - input, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._creation_fn_wrapper(mx.zeros_like)(input, dtype=_dtype) def randn( self, @@ -239,13 +229,11 @@ def randn( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.normal)( - shape=_shape, dtype=utils.dtype_map[_dtype] - ) + return self._creation_fn_wrapper(mx.random.normal)(shape=_shape, dtype=_dtype) def rand( self, @@ -253,13 +241,11 @@ def rand( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.uniform)( - shape=_shape, dtype=utils.dtype_map[_dtype] - ) + return self._creation_fn_wrapper(mx.random.uniform)(shape=_shape, dtype=_dtype) def randint( self, @@ -269,12 +255,12 @@ def randint( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(mx.random.randint)( - low=low, high=high, shape=_shape, dtype=utils.dtype_map[_dtype] + low=low, high=high, shape=_shape, dtype=_dtype ) def rand_uniform( @@ -285,21 +271,19 @@ def rand_uniform( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(mx.random.uniform)( - low=low, high=high, shape=_shape, dtype=utils.dtype_map[_dtype] + low=low, high=high, shape=_shape, dtype=_dtype ) def arange(self, *args: float | int, dtype: Dtype | None = None) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._creation_fn_wrapper(mx.arange)( - *args, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._creation_fn_wrapper(mx.arange)(*args, dtype=_dtype) def linspace( self, @@ -308,12 +292,10 @@ def linspace( steps: int | mx.array, dtype: Dtype | None = None, ) -> mx.array: - _dtype: str | None = None + _dtype: mx.Dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._creation_fn_wrapper(mx.linspace)( - start, stop, steps, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._creation_fn_wrapper(mx.linspace)(start, stop, steps, dtype=_dtype) def flatten( self, input: mx.array, start_dim: int = 0, end_dim: int = -1 @@ -687,7 +669,7 @@ def vjp( return output, vjp, aux - def vmap( + def vmap( # type: ignore #mypy bug self, fn: Callable[[mx.array], mx.array] ) -> Callable[[mx.array], mx.array]: return mx.vmap(fn) diff --git a/mithril/backends/with_autograd/mlx_backend/utils.py b/mithril/backends/with_autograd/mlx_backend/utils.py index 04291eca..e388b7ae 100644 --- a/mithril/backends/with_autograd/mlx_backend/utils.py +++ b/mithril/backends/with_autograd/mlx_backend/utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Callable, Sequence from functools import partial from typing import Any, TypeGuard @@ -26,7 +28,7 @@ ArrayType = mx.array -dtype_map: dict[str | None, Any] = { +dtype_map: dict[str, mx.Dtype] = { "int8": mx.int8, "int16": mx.int16, "short": mx.int16, @@ -38,7 +40,6 @@ "float32": mx.float32, "float": mx.float32, "bool": mx.bool_, # type: ignore - None: None, } diff --git a/mithril/backends/with_autograd/torch_backend/backend.py b/mithril/backends/with_autograd/torch_backend/backend.py index 9a2aa763..63e0facf 100644 --- a/mithril/backends/with_autograd/torch_backend/backend.py +++ b/mithril/backends/with_autograd/torch_backend/backend.py @@ -48,7 +48,7 @@ class TorchBackend(ParallelBackend[torch.Tensor]): The precision of the tensors, either 32 or 64, default is 32. """ - type = "torch" + backend_type = "torch" registered_primitives = {} primitive_fn_path = "mithril.backends.with_autograd.torch_backend.ops" @@ -267,7 +267,7 @@ def _create_parallel(self, device_mesh: tuple[int, ...]): self._raw_device_mesh ) - def _run_callable(self, *primals, fn_name: str): + def _run_callable(self, *primals: Any, fn_name: str): assert ( self._parallel_manager is not None ), "Parallel manager is not initialized!" @@ -298,11 +298,11 @@ def array( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] return self._conversion_fn_wrapper(torch.tensor)( - input, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + input, dtype=_dtype, device_mesh=device_mesh ) def zeros( @@ -311,12 +311,12 @@ def zeros( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(torch.zeros)( - _shape, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + _shape, dtype=_dtype, device_mesh=device_mesh ) def ones( @@ -325,12 +325,12 @@ def ones( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(torch.ones)( - _shape, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + _shape, dtype=_dtype, device_mesh=device_mesh ) def ones_like( @@ -340,11 +340,11 @@ def ones_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] return self._creation_fn_wrapper(torch.ones_like)( - input, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + input, dtype=_dtype, device_mesh=device_mesh ) def zeros_like( @@ -354,11 +354,11 @@ def zeros_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] return self._creation_fn_wrapper(torch.zeros_like)( - input, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + input, dtype=_dtype, device_mesh=device_mesh ) def randn( @@ -368,12 +368,12 @@ def randn( device_mesh: tuple[int, ...] | None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(torch.randn)( - size=_shape, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + size=_shape, dtype=_dtype, device_mesh=device_mesh ) def rand( @@ -383,12 +383,12 @@ def rand( device_mesh: tuple[int, ...] | None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(torch.rand)( - size=_shape, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + size=_shape, dtype=_dtype, device_mesh=device_mesh ) def randint( @@ -400,15 +400,15 @@ def randint( device_mesh: tuple[int, ...] | None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(torch.randint)( low, high, size=_shape, - dtype=utils.dtype_map[_dtype], + dtype=_dtype, device_mesh=device_mesh, ) @@ -432,11 +432,11 @@ def _arange( device_mesh: tuple[int, ...] | None = None, **kwargs: int | float, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] return self._creation_fn_wrapper(torch.arange)( - *args, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + *args, dtype=_dtype, device_mesh=device_mesh ) def linspace( @@ -447,11 +447,11 @@ def linspace( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: str | None = None + _dtype: torch.dtype | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] return self._creation_fn_wrapper(torch.linspace)( - start, stop, steps, dtype=utils.dtype_map[_dtype], device_mesh=device_mesh + start, stop, steps, dtype=_dtype, device_mesh=device_mesh ) def flatten( @@ -708,7 +708,9 @@ def vjp( (vjp,) = vjp return output, vjp, aux - def vmap(self, fn: Callable[..., dict[str, torch.Tensor]]) -> Callable: + def vmap( # type: ignore # mypy bug + self, fn: Callable[..., dict[str, torch.Tensor]] + ) -> Callable[..., dict[str, torch.Tensor]]: return torch_vmap(fn) def jacrev(self, fn: Callable[..., dict[str, torch.Tensor]]) -> Callable: diff --git a/mithril/backends/with_autograd/torch_backend/stensor.py b/mithril/backends/with_autograd/torch_backend/stensor.py index 497df3a5..3ee26445 100644 --- a/mithril/backends/with_autograd/torch_backend/stensor.py +++ b/mithril/backends/with_autograd/torch_backend/stensor.py @@ -13,6 +13,7 @@ # limitations under the License. from collections.abc import Callable +from typing import Any import torch from torch.distributed._tensor import DTensor @@ -41,7 +42,13 @@ def extract_ref(data): return data @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__( + cls, + func: Any, + types: Any, + args: tuple[Any, ...] = (), + kwargs: dict[str, Any] | None = None, + ): operator_name = func._name.split("::")[1].split(".")[0] args_ref = STensor.extract_ref(args) kwargs_ref = STensor.extract_ref(kwargs) diff --git a/mithril/backends/with_autograd/torch_backend/utils.py b/mithril/backends/with_autograd/torch_backend/utils.py index 272cd9e9..b5254171 100644 --- a/mithril/backends/with_autograd/torch_backend/utils.py +++ b/mithril/backends/with_autograd/torch_backend/utils.py @@ -36,7 +36,7 @@ AVAILABLE_BACKEND_TYPES = ["cpu", "cuda"] ArrayType = torch.Tensor -dtype_map: dict[str | None, torch.dtype | None] = { +dtype_map: dict[str, torch.dtype] = { "int16": torch.int16, "int32": torch.int32, "int": torch.int32, @@ -48,7 +48,6 @@ "float64": torch.float64, "double": torch.float64, "bool": torch.bool, - None: None, } diff --git a/mithril/backends/with_manualgrad/c_backend/backend.py b/mithril/backends/with_manualgrad/c_backend/backend.py index b09dadfa..7758166e 100644 --- a/mithril/backends/with_manualgrad/c_backend/backend.py +++ b/mithril/backends/with_manualgrad/c_backend/backend.py @@ -63,7 +63,7 @@ def zeros( _shape = process_shape(shape) return array.zeros(_shape) - def zeros_like(self, input: PyArray, dtype: core.Dtype | None = None) -> PyArray: + def zeros_like(self, input: PyArray, *, dtype: core.Dtype | None = None) -> PyArray: assert dtype is None, "dtype is not supported in CBackend" return self.array(np.zeros(input.shape, dtype=np.float32)) @@ -71,7 +71,7 @@ def to_numpy(self, array: PyArray) -> np.ndarray[Any, Any]: return utils.to_numpy(array) def array( - self, input: np.ndarray[Any, Any], dtype: core.Dtype | None = None + self, input: np.ndarray[Any, Any], *, dtype: core.Dtype | None = None ) -> PyArray: assert dtype is None, "dtype is not supported in CBackend" input = input.astype(np.float32) diff --git a/mithril/backends/with_manualgrad/c_backend/src/array.py b/mithril/backends/with_manualgrad/c_backend/src/array.py index 8a28d250..c0af80ba 100644 --- a/mithril/backends/with_manualgrad/c_backend/src/array.py +++ b/mithril/backends/with_manualgrad/c_backend/src/array.py @@ -51,7 +51,7 @@ class Array(ctypes.Structure): lib.delete_struct.argtypes = [ctypes.POINTER(Array)] -def to_c_int_array(lst: list[Any]): +def to_c_int_array(lst: list[Any] | tuple[Any, ...]): return (ctypes.c_int * len(lst))(*lst) @@ -81,7 +81,7 @@ def data(self): data_list = [data_ptr[i] for i in range(total_elements)] # Reshape the flat list based on the shape - def reshape(data, shape): + def reshape(data: PyArray, shape: tuple[int, ...]): if len(shape) == 1: return data size = shape[0] diff --git a/mithril/backends/with_manualgrad/common_primitives.py b/mithril/backends/with_manualgrad/common_primitives.py index fe90fa9b..538a159d 100644 --- a/mithril/backends/with_manualgrad/common_primitives.py +++ b/mithril/backends/with_manualgrad/common_primitives.py @@ -229,7 +229,11 @@ def to_list(*args: tuple[int | float | bool, ...], cache: CacheType = None): return list(args) -def padding_converter_1d(input, kernel_size, cache: CacheType = None): +def padding_converter_1d( + input: PaddingType | int | tuple[int, int], + kernel_size: int | tuple[int, int], + cache: CacheType = None, +) -> tuple[int, int]: if isinstance(input, PaddingType): if input == PaddingType.VALID: output = (0, 0) @@ -246,45 +250,50 @@ def padding_converter_1d(input, kernel_size, cache: CacheType = None): elif isinstance(input, int): output = (input, input) - elif isinstance(input, Sequence): + else: if isinstance(input[0], Sequence) or isinstance(input[1], Sequence): raise RuntimeError(f"Given input '{input}' is not valid!") - output = tuple(input) + output = input return output -def padding_converter_2d(input, kernel_size, cache: CacheType = None): - output: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] +def padding_converter_2d( + input: PaddingType + | int + | tuple[int, int] + | tuple[tuple[int, int] | tuple[int, int]], + kernel_size: int | tuple[int, int], + cache: CacheType = None, +) -> tuple[tuple[int, int], tuple[int, int]]: if isinstance(input, PaddingType): if input == PaddingType.VALID: - output = (0, 0) + output = ((0, 0), (0, 0)) elif isinstance(kernel_size, tuple): if kernel_size[0] % 2 == 0 or kernel_size[1] % 2 == 0: raise RuntimeError( "'same' padding is not supported when the kernel size is even!" ) - output = (kernel_size[0] // 2, kernel_size[1] // 2) - elif isinstance(kernel_size, int): + output = ( + (kernel_size[0] // 2, kernel_size[1] // 2), + (kernel_size[0] // 2, kernel_size[1] // 2), + ) + else: if kernel_size % 2 == 0: raise RuntimeError( "'same' padding is not supported when the kernel size is even!" ) half = kernel_size // 2 output = ((half, half), (half, half)) - else: - raise RuntimeError("Kernel size must be 'tuple[int, int]' or 'int'!") elif isinstance(input, int): - output = (input, input) - elif isinstance(input, Sequence): - _output = [] + output = ((input, input), (input, input)) + else: + _output: list[tuple[int, int]] = [] for p in input: if isinstance(p, int): _output.append((p, p)) - elif isinstance(input, Sequence) and len(p) == 2: - _output.append(tuple(p)) - else: - raise RuntimeError(f"Given input '{input}' is not valid!") + elif len(p) == 2: + _output.append(p) output = ((_output[0][0], _output[0][1]), (_output[1][0], _output[1][1])) return output @@ -304,14 +313,18 @@ def swapaxes( return input.swapaxes(axis1, axis2) -def stride_converter(input, kernel_size, cache: CacheType = None): +def stride_converter( + input: int | tuple[int, int] | None, + kernel_size: int | tuple[int, int], + cache: CacheType = None, +): if input is None: return kernel_size else: return input -def tuple_converter(input, cache: CacheType = None): +def tuple_converter(input: int | tuple[int, int], cache: CacheType = None): if isinstance(input, int): return (input, input) else: diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index 62add425..ae870f24 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -37,7 +37,7 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]): The precision of the arrays, either 32 or 64, default is 32. """ - type = "numpy" + backend_type = "numpy" registered_primitives = {} primitive_fn_path = "mithril.backends.with_manualgrad.numpy_backend.ops" @@ -156,53 +156,45 @@ def accumulate_grads( ) -> np.ndarray[Any, Any]: return utils.accumulate_grads(gradient, input, cache, idx) - def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]: - _dtype: str | None = None + def array(self, input: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._conversion_fn_wrapper(np.array)( - data, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._conversion_fn_wrapper(np.array)(input, dtype=_dtype) def zeros( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None - ) -> np.ndarray: - _dtype: str | None = None + ) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) - return self._creation_fn_wrapper(np.zeros)( - shape=_shape, dtype=utils.dtype_map[_dtype] - ) + return self._creation_fn_wrapper(np.zeros)(shape=_shape, dtype=_dtype) def ones( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None - ) -> np.ndarray: - _dtype: str | None = None + ) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) - return self._creation_fn_wrapper(np.ones)( - shape=_shape, dtype=utils.dtype_map[_dtype] - ) + return self._creation_fn_wrapper(np.ones)(shape=_shape, dtype=_dtype) - def ones_like(self, input: np.ndarray, *, dtype: Dtype | None = None) -> np.ndarray: - _dtype: str | None = None + def ones_like( + self, input: np.ndarray[Any, Any], *, dtype: Dtype | None = None + ) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._creation_fn_wrapper(np.ones_like)( - input, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._creation_fn_wrapper(np.ones_like)(input, dtype=_dtype) def zeros_like( self, input: np.ndarray[Any, Any], *, dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: str | None = None + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._creation_fn_wrapper(np.zeros_like)( - input, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._creation_fn_wrapper(np.zeros_like)(input, dtype=_dtype) def randn( self, @@ -210,27 +202,23 @@ def randn( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: str | None = None + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.randn)( - *_shape, dtype=utils.dtype_map[_dtype] - ) + return self._creation_fn_wrapper(np.random.randn)(*_shape, dtype=_dtype) def rand( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None, prng_key: Any = None, - ) -> np.ndarray: - _dtype: str | None = None + ) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.rand)( - *_shape, dtype=utils.dtype_map[_dtype] - ) + return self._creation_fn_wrapper(np.random.rand)(*_shape, dtype=_dtype) def randint( self, @@ -239,109 +227,113 @@ def randint( *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None, prng_key: Any = None, - ) -> np.ndarray: - _dtype: str | None = None + ) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(np.random.randint)( - low=low, high=high, size=_shape, dtype=utils.dtype_map[_dtype] + low=low, high=high, size=_shape, dtype=_dtype ) def rand_uniform( self, - low: int | float | bool | np.ndarray, - high: int | float | bool | np.ndarray, + low: int | float | bool | np.ndarray[Any, Any], + high: int | float | bool | np.ndarray[Any, Any], *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None, prng_key: Any = None, - ) -> np.ndarray: - _dtype: str | None = None + ) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name + _dtype = utils.dtype_map[dtype.name] _shape = process_shape(shape) return self._creation_fn_wrapper(np.random.uniform)( - low=low, high=high, size=_shape, dtype=utils.dtype_map[_dtype] + low=low, high=high, size=_shape, dtype=_dtype ) - def arange(self, *args, dtype: Dtype | None = None) -> np.ndarray: - _dtype: str | None = None + def arange( + self, *args: int | float, dtype: Dtype | None = None + ) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._creation_fn_wrapper(np.arange)( - *args, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._creation_fn_wrapper(np.arange)(*args, dtype=_dtype) def linspace( self, - start: int | float | bool | np.ndarray, - stop: int | float | bool | np.ndarray, - steps: int | np.ndarray, + start: int | float | bool | np.ndarray[Any, Any], + stop: int | float | bool | np.ndarray[Any, Any], + steps: int | np.ndarray[Any, Any], dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, - ) -> np.ndarray: - _dtype: str | None = None + ) -> np.ndarray[Any, Any]: + _dtype: np.dtype[Any] | None = None if isinstance(dtype, Dtype): - _dtype = dtype.name - return self._creation_fn_wrapper(np.linspace)( - start, stop, steps, dtype=utils.dtype_map[_dtype] - ) + _dtype = utils.dtype_map[dtype.name] + return self._creation_fn_wrapper(np.linspace)(start, stop, steps, dtype=_dtype) def flatten( - self, input: np.ndarray, start_dim: int = 0, end_dim: int = -1 - ) -> np.ndarray: + self, input: np.ndarray[Any, Any], start_dim: int = 0, end_dim: int = -1 + ) -> np.ndarray[Any, Any]: return ops.flatten(input, start_dim=start_dim, end_dim=end_dim) - def abs(self, input: np.ndarray) -> np.ndarray: + def abs(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.abs(input) - def sign(self, input: np.ndarray) -> np.ndarray: + def sign(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.sign(input) - def sin(self, input: np.ndarray) -> np.ndarray: + def sin(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.sin(input) - def cos(self, input: np.ndarray) -> np.ndarray: + def cos(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.cos(input) - def tanh(self, input: np.ndarray) -> np.ndarray: + def tanh(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.tanh(input) - def relu(self, input: np.ndarray) -> np.ndarray: + def relu(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return ops.relu(input) - def leaky_relu(self, input: np.ndarray, slope: float | np.ndarray) -> np.ndarray: + def leaky_relu( + self, input: np.ndarray[Any, Any], slope: float | np.ndarray[Any, Any] + ) -> np.ndarray[Any, Any]: return ops.leaky_relu(input, slope) - def sigmoid(self, input: np.ndarray) -> np.ndarray: + def sigmoid(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return ops.sigmoid(input) - def softplus(self, input: np.ndarray) -> np.ndarray: + def softplus(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return ops.softplus(input) - def softmax(self, input: np.ndarray, dim: int = -1) -> np.ndarray: + def softmax( + self, input: np.ndarray[Any, Any], dim: int = -1 + ) -> np.ndarray[Any, Any]: # TODO: dim can be Sequence[int] as well. Should work # for all backends. return ops.softmax(input, axis=dim) - def log(self, input: np.ndarray) -> np.ndarray: + def log(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.log(input) - def isnan(self, input: np.ndarray) -> np.ndarray: + def isnan(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.isnan(input) - def stop_gradient(self, data: np.ndarray) -> np.ndarray: - return ops.stop_gradient(data) + def stop_gradient(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: + return ops.stop_gradient(input) - def squeeze(self, input: np.ndarray) -> np.ndarray: + def squeeze(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.squeeze(input) - def reshape(self, input: np.ndarray, shape: tuple[int, ...]) -> np.ndarray: + def reshape( + self, input: np.ndarray[Any, Any], shape: tuple[int, ...] + ) -> np.ndarray[Any, Any]: return np.reshape(input, shape) def sort( - self, input: np.ndarray, axis: int = -1, descending: bool = False - ) -> np.ndarray: + self, input: np.ndarray[Any, Any], axis: int = -1, descending: bool = False + ) -> np.ndarray[Any, Any]: if descending: return -np.sort(-input, axis=axis) return np.sort( @@ -349,50 +341,63 @@ def sort( axis=axis, ) - def expand_dims(self, input: np.ndarray, axis: int) -> np.ndarray: + def expand_dims( + self, input: np.ndarray[Any, Any], axis: int + ) -> np.ndarray[Any, Any]: return np.expand_dims(input, axis) - def stack(self, inputs: list[np.ndarray], axis: int = 0) -> np.ndarray: + def stack( + self, inputs: list[np.ndarray[Any, Any]], axis: int = 0 + ) -> np.ndarray[Any, Any]: return np.stack(inputs, axis=axis) def cat( - self, inputs: tuple[np.ndarray, ...] | list[np.ndarray], axis: int = 0 - ) -> np.ndarray: + self, + inputs: tuple[np.ndarray[Any, Any], ...] | list[np.ndarray[Any, Any]], + axis: int = 0, + ) -> np.ndarray[Any, Any]: return ops.concat(*inputs, axis=axis) - def pad(self, input: np.ndarray, pad_width: PadWidthType) -> np.ndarray: + def pad( + self, input: np.ndarray[Any, Any], pad_width: PadWidthType + ) -> np.ndarray[Any, Any]: return np.pad(input, pad_width) - def all(self, input: np.ndarray) -> np.ndarray: + def all(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.array(np.all(input)) - def any(self, input: np.ndarray) -> np.ndarray: + def any(self, input: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]: return np.array(np.any(input)) def atleast_1d( - self, inputs: np.ndarray | tuple[np.ndarray, ...] - ) -> np.ndarray | tuple[np.ndarray, ...]: + self, inputs: np.ndarray[Any, Any] | tuple[np.ndarray[Any, Any], ...] + ) -> np.ndarray[Any, Any] | tuple[np.ndarray[Any, Any], ...]: if isinstance(inputs, tuple): return np.atleast_1d(*inputs) else: return np.atleast_1d(inputs) def atleast_2d( - self, inputs: np.ndarray | tuple[np.ndarray, ...] - ) -> np.ndarray | tuple[np.ndarray, ...]: + self, inputs: np.ndarray[Any, Any] | tuple[np.ndarray[Any, Any], ...] + ) -> np.ndarray[Any, Any] | tuple[np.ndarray[Any, Any], ...]: if isinstance(inputs, tuple): return np.atleast_2d(*inputs) else: return np.atleast_2d(inputs) def transpose( - self, input: np.ndarray, axes: tuple[int, ...] | list[int] | None = None - ) -> np.ndarray: + self, + input: np.ndarray[Any, Any], + axes: tuple[int, ...] | list[int] | None = None, + ) -> np.ndarray[Any, Any]: return ops.transpose(input, axes) def where( - self, cond: np.ndarray, input1: np.ndarray, input2: np.ndarray - ) -> np.ndarray: + self, + cond: np.ndarray[Any, Any], + input1: np.ndarray[Any, Any], + input2: np.ndarray[Any, Any], + ) -> np.ndarray[Any, Any]: return ops.where(cond, input1, input2) # TODO: Analyze the code's efficiency and refactor it if necessary. @@ -400,20 +405,20 @@ def where( # TODO: Now topk only supports one dimensional tensors, # add multi-dimensional support similar to torch, jax and mlx - def topk(self, array: np.ndarray, k: int) -> np.ndarray: - flat = array.ravel() + def topk(self, input: np.ndarray[Any, Any], k: int) -> np.ndarray[Any, Any]: + flat = input.ravel() indices = np.argpartition(flat, -k)[-k:] argsort = np.argsort(-flat[indices]) indices = indices[argsort] values = flat[indices] - leading_dims = len(array.shape) - len(values.shape) + leading_dims = len(input.shape) - len(values.shape) values = values.reshape((-1,) * leading_dims + values.shape) return values def multinomial( - self, probs: np.ndarray, num_samples: int, replacement: bool = False - ) -> np.ndarray: + self, probs: np.ndarray[Any, Any], num_samples: int, replacement: bool = False + ) -> np.ndarray[Any, Any]: # input = np.asarray(probs) if probs.ndim == 1: probs = probs[None, :] diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops.py b/mithril/backends/with_manualgrad/numpy_backend/ops.py index 3b1735df..6132d17a 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops.py @@ -424,7 +424,7 @@ def reduce_sum( axis: int | tuple[int, ...] | None = None, keepdim: bool = False, cache: CacheType | None = None, -) -> np.ndarray: +) -> np.ndarray[Any, Any]: return np.sum(input, axis=axis, keepdims=keepdim) @@ -929,7 +929,7 @@ def broadcast_to( def ones_with_zero_diag( - *args, precision: int, cache: CacheType | None = None + *args: Any, precision: int, cache: CacheType | None = None ) -> np.ndarray[Any, Any]: n, m = args output = np.ones((n, m)) - np.eye(n, m) if m is not None else np.ones(n) - np.eye(n) @@ -983,7 +983,7 @@ def concat( def arange( - *args, precision: int, cache: CacheType | None = None + *args: int | float, precision: int, cache: CacheType | None = None ) -> np.ndarray[Any, Any]: return handle_data_precision(np.arange(*args), precision) @@ -1065,7 +1065,7 @@ def polynomial_features( identity = np.eye(dims + 1, dims + 1, dtype=int) data = np.hstack((np.ones((samples, 1), dtype=input.dtype), input)) write_into_cache(cache, "data", data) - powers: Iterator = map(sum, combinations_with_replacement(identity, degree)) + powers: Iterator[int] = map(sum, combinations_with_replacement(identity, degree)) # Skip first element of powers. This is the bias term. next(powers) write_into_cache( diff --git a/mithril/backends/with_manualgrad/numpy_backend/utils.py b/mithril/backends/with_manualgrad/numpy_backend/utils.py index 3cbab347..c7a101ca 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/utils.py +++ b/mithril/backends/with_manualgrad/numpy_backend/utils.py @@ -24,7 +24,7 @@ ArrayType = np.ndarray -dtype_map = { +dtype_map: dict[str, Any] = { "int16": np.int16, "int32": np.int32, "int": np.int32, @@ -36,7 +36,6 @@ "float64": np.float64, "double": np.float64, "bool": np.bool_, - None: None, } CacheType = dict[str, Any] diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index 97d1c386..9fecb178 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -252,7 +252,8 @@ def import_backend(self): module="mithril", names=[ ast.alias( - name=f"{self.pm.backend.type.capitalize()}Backend", asname="Backend" + name=f"{self.pm.backend.backend_type.capitalize()}Backend", + asname="Backend", ) ], level=0, diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 182b7a5d..b80559aa 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -420,7 +420,7 @@ def _flatten_dag( self._infer_differentiability(model, dag) # NOTE: maybe move adding cache to generate_code methods. - if self.backend.type == "numpy": + if self.backend.backend_type == "numpy": cache_name = "_".join([dag[output], model.cache_name]) dag["cache"] = cache_name cache_value: dict[str, MainValueType] | None = ( @@ -947,7 +947,7 @@ def _print_model_info( output_keys = pm_output_keys pm_info = { - "Backend type": [self.backend.type], + "Backend type": [self.backend.backend_type], "Backend precision": [str(self.backend.precision)], "Backend device": [str(self.backend.device)], "Output keys": sorted(output_keys), diff --git a/mithril/utils/utils.py b/mithril/utils/utils.py index 8a51a644..f4cc7cb9 100755 --- a/mithril/utils/utils.py +++ b/mithril/utils/utils.py @@ -344,7 +344,7 @@ def find_slot_lengths(stacked_data, lengths): data_sorted = data max_size = lengths[0] - if backend.type == "torch": + if backend.backend_type == "torch": padded = [ backend.pad(arr[index], (0, 0, 0, max_size - arr[index].shape[0])) for arr in data diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 29d789ab..3fddfc3d 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -141,7 +141,7 @@ def evaluate_case( # assert set(outputs.keys()) == set(reference_outputs) for k, v in reference_outputs.items(): if isinstance(v, dict): - v = v[backend.type] + v = v[backend.backend_type] out = outputs.get(k, None) # We may not include any reference value for some keys for a certain test. # So we don't assert set(outputs.keys()) == set(reference_outputs) since @@ -166,7 +166,7 @@ def evaluate_case( for k, v in reference_gradients.items(): if isinstance(v, dict): - v = v[backend.type] + v = v[backend.backend_type] grad = model_grad[k] if grad is None: assert v == grad diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index 64982d8b..15c4e09e 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -160,7 +160,7 @@ def compile_and_compare( for k, v in backend_ref_outputs.items(): if isinstance(v, dict): - v = v[backend.type] + v = v[backend.backend_type] out = outputs.get(k, None) # We may not include any reference value for some keys for a certain test. # So we don't assert set(outputs.keys()) == set(reference_outputs) since @@ -205,7 +205,7 @@ def compile_and_compare( for k, v in backend_ref_gradients.items(): if isinstance(v, dict): - v = v[backend.type] + v = v[backend.backend_type] grad = gradients[k] if grad is None: assert v == grad @@ -2415,7 +2415,7 @@ def test_cast_int16(): res = pm.evaluate() res_out = res["output"] assert isinstance(res_out, backend.DataType) - assert res_out.dtype == expected_dtypes[backend.type] + assert res_out.dtype == expected_dtypes[backend.backend_type] np.testing.assert_allclose(res_out, reference_outputs["output"]) @@ -2461,7 +2461,7 @@ def test_cast_int32(): res = pm.evaluate() res_out = res["output"] assert isinstance(res_out, backend.DataType) # type: ignore - assert res_out.dtype == expected_dtypes[backend.type] + assert res_out.dtype == expected_dtypes[backend.backend_type] np.testing.assert_allclose(res_out, reference_outputs["output"]) @@ -2505,7 +2505,7 @@ def test_cast_int64(): inference=True, ) res = pm.evaluate() - assert res["output"].dtype == expected_dtypes[backend.type] # type: ignore + assert res["output"].dtype == expected_dtypes[backend.backend_type] # type: ignore np.testing.assert_allclose(res["output"], reference_outputs["output"]) # type: ignore @@ -2550,7 +2550,7 @@ def test_cast_float16(): ) res = pm.evaluate()["output"] assert isinstance(res, backend.DataType) - assert res.dtype == expected_dtypes[backend.type] + assert res.dtype == expected_dtypes[backend.backend_type] np.testing.assert_allclose(res, reference_outputs["output"]) @@ -2596,7 +2596,7 @@ def test_cast_float32(): res = pm.evaluate() res_out = res["output"] assert isinstance(res_out, backend.DataType) # type: ignore - assert res_out.dtype == expected_dtypes[backend.type] + assert res_out.dtype == expected_dtypes[backend.backend_type] np.testing.assert_allclose(res_out, reference_outputs["output"]) @@ -2638,7 +2638,7 @@ def test_cast_float64(): res = pm.evaluate() res_out = res["output"] assert isinstance(res_out, backend.DataType) # type: ignore - assert res_out.dtype == expected_dtypes[backend.type] + assert res_out.dtype == expected_dtypes[backend.backend_type] np.testing.assert_allclose(res_out, reference_outputs["output"]) @@ -2684,7 +2684,7 @@ def test_cast_bool(): res = pm.evaluate() res_out = res["output"] assert isinstance(res_out, backend.DataType) # type: ignore - assert res_out.dtype == expected_dtypes[backend.type] + assert res_out.dtype == expected_dtypes[backend.backend_type] np.testing.assert_allclose(res_out, reference_outputs["output"]) diff --git a/tests/scripts/test_backend_fns.py b/tests/scripts/test_backend_fns.py index c1e22baf..6d743390 100644 --- a/tests/scripts/test_backend_fns.py +++ b/tests/scripts/test_backend_fns.py @@ -122,8 +122,8 @@ def assert_backend_results_equal( for out, ref in zip(output, ref_output, strict=False): assert tuple(out.shape) == tuple(ref.shape) - assert get_array_device(out, backend.type) == ref_output_device - assert get_array_precision(out, backend.type) == ref_output_precision + assert get_array_device(out, backend.backend_type) == ref_output_device + assert get_array_precision(out, backend.backend_type) == ref_output_precision assert testing_fn(out, ref, rtol=rtol, atol=atol) diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index 3c1d6c6c..080f4b95 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -116,7 +116,7 @@ def assert_all_backends_device_precision(model: Model): # remove unsupported backend, device and precision trios if (backend_class, device, precision) in unsupported_device_precisions: continue - _type = backend_class.type + _type = backend_class.backend_type backend = backend_class(device=device, precision=precision) comp_model = mithril.compile( diff --git a/tests/scripts/test_randomized_models_all_backends.py b/tests/scripts/test_randomized_models_all_backends.py index fcb78758..c7454409 100644 --- a/tests/scripts/test_randomized_models_all_backends.py +++ b/tests/scripts/test_randomized_models_all_backends.py @@ -145,7 +145,7 @@ def test_randomized(case: str) -> None: floated_randomized_args = current_case["model"].pop("floats", {}) regular_args = current_case["model"].pop("regular_args", {}) init_backend = avaliable_backends.pop(0) - init_key = init_backend.type + init_key = init_backend.backend_type static_input_info = current_case.pop("static_input_info", {}) input_info: dict[str, dict[str, list]] = current_case.pop("input_info", {}) @@ -219,14 +219,14 @@ def test_randomized(case: str) -> None: if key not in compiled_model.ignore_grad_keys } for backend in avaliable_backends: - output_gradients[backend.type] = { + output_gradients[backend.backend_type] = { key: backend.array(value) for key, value in output_gradients[init_key].items() } - inputs[backend.type] = { + inputs[backend.backend_type] = { key: backend.array(value) for key, value in inputs[init_key].items() } - static_inputs[backend.type] = { + static_inputs[backend.backend_type] = { key: backend.array(value) if isinstance(model.conns._get_metadata(key).data, Tensor) else value @@ -240,15 +240,15 @@ def test_randomized(case: str) -> None: for backend in avaliable_backends: compiled_model = compile( model=model, - constant_keys=static_inputs[backend.type], + constant_keys=static_inputs[backend.backend_type], backend=backend, # type: ignore[reportArgumentType] shapes=shapes, jit=True, ) - outputs[backend.type], gradients[backend.type] = ( + outputs[backend.backend_type], gradients[backend.backend_type] = ( compiled_model.evaluate_all( - inputs[backend.type], - output_gradients=output_gradients[backend.type], + inputs[backend.backend_type], + output_gradients=output_gradients[backend.backend_type], ) ) @@ -281,16 +281,17 @@ def test_randomized(case: str) -> None: assert numeric_value == inferred_shapes for backend in avaliable_backends: - outputs[backend.type] = { - key: np.array(value) for key, value in outputs[backend.type].items() + outputs[backend.backend_type] = { + key: np.array(value) + for key, value in outputs[backend.backend_type].items() } - gradients[backend.type] = { + gradients[backend.backend_type] = { key: np.array(value) - for key, value in gradients[backend.type].items() + for key, value in gradients[backend.backend_type].items() } for backend in avaliable_backends: - for k, v in outputs[backend.type].items(): + for k, v in outputs[backend.backend_type].items(): np.testing.assert_allclose( outputs["numpy"][k], v, @@ -299,7 +300,7 @@ def test_randomized(case: str) -> None: ) for backend in avaliable_backends: - for k, v in gradients[backend.type].items(): + for k, v in gradients[backend.backend_type].items(): np.testing.assert_allclose( gradients["numpy"][k], v, diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 99a52e97..44235a53 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -1297,9 +1297,9 @@ def test_logical_model_compile_twice(): constant_keys=static_keys_torch, ) - assert torch_model.backend.type == "torch" - assert jax_model.backend.type == "jax" - assert np_model.backend.type == "numpy" + assert torch_model.backend.backend_type == "torch" + assert jax_model.backend.backend_type == "jax" + assert np_model.backend.backend_type == "numpy" def test_canonical_output_compile(): From 3ad4dfeb7b423cbd03418528b2b731d2fc307304 Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Thu, 26 Dec 2024 17:49:47 +0300 Subject: [PATCH 22/26] underscored attributes are edited --- mithril/__init__.py | 5 +- mithril/backends/backend.py | 11 +- .../with_autograd/jax_backend/backend.py | 5 +- .../with_autograd/mlx_backend/backend.py | 3 + .../with_autograd/torch_backend/backend.py | 7 +- .../with_autograd/torch_backend/parallel.py | 2 +- .../with_autograd/torch_backend/stensor.py | 2 +- .../with_manualgrad/numpy_backend/backend.py | 3 + mithril/framework/codegen/c_gen.py | 18 +- mithril/framework/codegen/numpy_gen.py | 44 ++-- mithril/framework/codegen/python_gen.py | 26 +- mithril/framework/codegen/torch_gen.py | 4 +- mithril/framework/codegen/utils.py | 2 +- mithril/framework/common.py | 133 ++++++---- mithril/framework/constraints.py | 154 ++++++------ mithril/framework/logical/base.py | 79 +++--- .../framework/logical/essential_primitives.py | 4 +- mithril/framework/logical/model.py | 137 +++++------ mithril/framework/logical/primitive.py | 38 +-- mithril/framework/physical/data_store.py | 14 +- mithril/framework/physical/flat_graph.py | 2 +- mithril/framework/physical/model.py | 90 +++---- mithril/models/models.py | 66 ++--- mithril/models/primitives.py | 18 +- mithril/models/train_model.py | 60 ++--- mithril/utils/dict_conversions.py | 22 +- mithril/utils/utils.py | 57 +---- tests/scripts/helper.py | 16 +- tests/scripts/test_constant_inputs.py | 6 +- tests/scripts/test_constr_counter.py | 4 +- tests/scripts/test_constraints.py | 8 +- tests/scripts/test_data_store.py | 36 +-- tests/scripts/test_differentiablity.py | 24 +- tests/scripts/test_extend_template.py | 14 +- tests/scripts/test_functions.py | 6 +- tests/scripts/test_inference.py | 4 +- tests/scripts/test_io_key.py | 16 +- tests/scripts/test_key_namings.py | 68 +++--- tests/scripts/test_key_values_in_init.py | 34 +-- tests/scripts/test_primitive_calls.py | 2 +- .../test_randomized_models_all_backends.py | 4 +- tests/scripts/test_scripts.py | 228 +++++++----------- tests/scripts/test_set_outputs.py | 8 +- tests/scripts/test_set_types.py | 54 ++--- tests/scripts/test_shapes.py | 58 ++--- tests/scripts/test_train_context.py | 16 +- tests/scripts/test_type_coercion.py | 100 ++++---- tests/scripts/test_type_consistencies.py | 48 ++-- tests/scripts/test_utils.py | 10 +- tests/utils.py | 4 +- 50 files changed, 877 insertions(+), 897 deletions(-) diff --git a/mithril/__init__.py b/mithril/__init__.py index 524e2e69..f1589072 100644 --- a/mithril/__init__.py +++ b/mithril/__init__.py @@ -126,11 +126,8 @@ def compile( raise Exception("Model is not jittable. Can only be compiled with jit = False.") # TrainModel model requires to be finalized before compilation. if isinstance(model, TrainModel): - model._finalize() + model.finalize() - # Generate Physical Model. - if not isinstance(model, BaseModel): - raise Exception("Unsupported model type!") if model.parent is not None: raise ValueError("Model with a parent could not be compiled!") diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py index d2562b72..4c456659 100644 --- a/mithril/backends/backend.py +++ b/mithril/backends/backend.py @@ -64,6 +64,9 @@ def precision(self): def device(self): return self._device + def get_device(self): + return self._device + @property def inf(self) -> DataType | float: raise NotImplementedError("inf is not implemented") @@ -1055,6 +1058,12 @@ def __init__(self, device_mesh: tuple[int, ...] | None) -> None: self.n_devices = math.prod(device_mesh) if device_mesh is not None else 1 self._parallel_manager: Parallel[DataType] | None + def get_parallel_manager(self) -> Parallel[DataType] | None: + return self._parallel_manager + + def get_raw_device_mesh(self) -> tuple[int, ...] | None: + return self._raw_device_mesh + def zeros( self, *shape: int | tuple[int, ...] | list[int], @@ -1371,7 +1380,7 @@ def linspace( raise NotImplementedError("linspace is not implemented!") - def _register_callable[T: Any]( + def register_callable[T: Any]( self, fn: Callable[..., T] | partial[T], fn_name: str, jit: bool ) -> None: raise NotImplementedError() diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py index 56ff6e39..d3337e3f 100644 --- a/mithril/backends/with_autograd/jax_backend/backend.py +++ b/mithril/backends/with_autograd/jax_backend/backend.py @@ -90,6 +90,9 @@ def get_backend_array_type(self): def device(self): return utils.get_device(self._device) + def get_device(self): + return self._device + # TODO: This property is weird! Investigate why this property is used. @property def DataType(self): # noqa: N802 @@ -236,7 +239,7 @@ def _parallelize( return tensor return self._parallel_manager.parallelize(tensor, device_mesh) - def _register_callable( + def register_callable( self, fn: Callable[..., Any], fn_name: str, jit: bool = False ): assert ( diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index ec2a1a80..1b623ce7 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -64,6 +64,9 @@ def nan(self): def device(self): utils.get_device(self._device) + def get_device(self): + return self._device + @property def DataType(self): # noqa: N802 return utils.ArrayType diff --git a/mithril/backends/with_autograd/torch_backend/backend.py b/mithril/backends/with_autograd/torch_backend/backend.py index 63e0facf..5b5c39e2 100644 --- a/mithril/backends/with_autograd/torch_backend/backend.py +++ b/mithril/backends/with_autograd/torch_backend/backend.py @@ -96,6 +96,9 @@ def device(self): def get_backend_array_type(self): return torch.Tensor + def get_device(self): + return self._device + @staticmethod def register_primitive(fn: Callable[..., Any]) -> None: TorchBackend.registered_primitives[fn.__name__] = fn @@ -234,7 +237,7 @@ def _parallelize( tensor, self.base_device_mesh, device_mesh ) - def _register_callable( + def register_callable( self, fn: Callable[..., torch.Tensor], fn_name: str, jit: bool = False ): """ @@ -263,7 +266,7 @@ def _create_parallel(self, device_mesh: tuple[int, ...]): self._parallel_manager = TorchParallel( self.n_devices, device=self._device.split(":")[0] ) - self.base_device_mesh = self._parallel_manager._init_device_mesh( + self.base_device_mesh = self._parallel_manager.init_device_mesh( self._raw_device_mesh ) diff --git a/mithril/backends/with_autograd/torch_backend/parallel.py b/mithril/backends/with_autograd/torch_backend/parallel.py index acea8631..5b79f742 100644 --- a/mithril/backends/with_autograd/torch_backend/parallel.py +++ b/mithril/backends/with_autograd/torch_backend/parallel.py @@ -120,7 +120,7 @@ def _init_processes(self): atexit.register(self.clean_up) - def _init_device_mesh(self, mesh_shape: tuple[int, ...]) -> DeviceMesh: + def init_device_mesh(self, mesh_shape: tuple[int, ...]) -> DeviceMesh: if mesh_shape not in TorchParallel.device_meshes: self._send_instrcs(Instructions.INIT_MESH, None, mesh_shape, None) TorchParallel.device_meshes[mesh_shape] = init_device_mesh( diff --git a/mithril/backends/with_autograd/torch_backend/stensor.py b/mithril/backends/with_autograd/torch_backend/stensor.py index 3ee26445..1ce0e416 100644 --- a/mithril/backends/with_autograd/torch_backend/stensor.py +++ b/mithril/backends/with_autograd/torch_backend/stensor.py @@ -58,7 +58,7 @@ def __torch_dispatch__( dtensor = DTensor._op_dispatcher.dispatch(func, args, kwargs or {}) if isinstance(dtensor, torch.Tensor): - stensor = STensor.from_dtensor(dtensor) # type: ignore + stensor = STensor.from_dtensor(dtensor) else: stensor = dtensor diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index ae870f24..d1fa14f2 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -79,6 +79,9 @@ def DataType(self): # noqa: N802 def get_backend_array_type(self): return np.ndarray + def get_device(self): + return self._device + @staticmethod def get_available_devices() -> list[str]: """Static method to get available devices. Currently, in the NumpyBackend, diff --git a/mithril/framework/codegen/c_gen.py b/mithril/framework/codegen/c_gen.py index 5f9e32bb..aba85b5e 100644 --- a/mithril/framework/codegen/c_gen.py +++ b/mithril/framework/codegen/c_gen.py @@ -164,7 +164,7 @@ def evaluate_gradients_wrapper( # Create gradients for all params for key in ( - self.pm._flat_graph.all_source_keys + self.pm.flat_graph.all_source_keys - self.pm.data_store.all_static_keys - self.pm.data_store.unused_keys - self.pm.ignore_grad_keys @@ -199,19 +199,19 @@ def generate_evaluate(self) -> tuple[c_ast.FunctionDef, set[str]]: unused_keys = self.pm.data_store.unused_keys cached_data_keys = self.pm.data_store.cached_data.keys() - for output_key in self.pm._flat_graph.topological_order: + for output_key in self.pm.flat_graph.topological_order: # Staticly infered and unused model will not be added if output_key in (cached_data_keys | unused_keys): continue - model = self.pm._flat_graph.get_model(output_key) - inputs = self.pm._flat_graph.get_source_keys(output_key) + model = self.pm.flat_graph.get_model(output_key) + inputs = self.pm.flat_graph.get_source_keys(output_key) # In C backend we need to pass output array as first argument inputs = [output_key] + inputs # Create primitive call - p_call = self.create_primitive_call(model._formula_key, inputs) + p_call = self.create_primitive_call(model.formula_key, inputs) fn_body.append(p_call) used_keys.add(output_key) @@ -238,13 +238,13 @@ def generate_evaluate_gradients(self) -> tuple[c_ast.FunctionDef, set[str]]: set(), self.pm._output_keys, all_ignored_keys, update_graph=False ) - for output_key in reversed(self.pm._flat_graph.topological_order): + for output_key in reversed(self.pm.flat_graph.topological_order): # Staticly infered and unused model will not be added if output_key in all_ignored_keys: continue - model = self.pm._flat_graph.get_model(output_key) - inputs = self.pm._flat_graph.get_source_keys(output_key) + model = self.pm.flat_graph.get_model(output_key) + inputs = self.pm.flat_graph.get_source_keys(output_key) # Assume all inputs are Array grad_inputs = [input_key + "_grad" for input_key in inputs] @@ -257,7 +257,7 @@ def generate_evaluate_gradients(self) -> tuple[c_ast.FunctionDef, set[str]]: # Create primitive call p_call = self.create_primitive_call( - model._formula_key + self.BACKWARD_FN_SUFFIX, fn_inputs + model.formula_key + self.BACKWARD_FN_SUFFIX, fn_inputs ) fn_body.append(p_call) diff --git a/mithril/framework/codegen/numpy_gen.py b/mithril/framework/codegen/numpy_gen.py index 8e06ce41..37919a27 100644 --- a/mithril/framework/codegen/numpy_gen.py +++ b/mithril/framework/codegen/numpy_gen.py @@ -151,7 +151,7 @@ def evaluate_gradients_wrapper_manualgrad( # Initialize gradients as zero with corresponding shapes. gradients: dict[str, np.ndarray[Any, Any]] = {} for key in ( - self.pm._flat_graph.all_keys + self.pm.flat_graph.all_keys - self.pm.data_store.all_static_keys - self.pm.data_store.unused_keys - self.pm.ignore_grad_keys @@ -165,12 +165,12 @@ def evaluate_gradients_wrapper_manualgrad( out_data = key_cache["output"] else: # Removed primitives, to take shape of output take input shape - _key = self.pm._flat_graph.get_source_keys( + _key = self.pm.flat_graph.get_source_keys( key, include_outputs=True )[0] _key_cache = cached_data.get(_key + "_cache", {}) assert isinstance(_key_cache, dict) - if _key in self.pm._input_keys: + if _key in self.pm.input_keys: out_data = params[_key] else: out_data = _key_cache["output"] @@ -211,11 +211,11 @@ def evaluate_gradients_wrapper_manualgrad( return self.post_process_fns(eval_fn, grad_fn, jit) # type: ignore def get_primitive_details(self, output_key: str): - model = self.pm._flat_graph.get_model(output_key) + model = self.pm.flat_graph.get_model(output_key) - global_input_keys = self.pm._flat_graph.get_source_keys(output_key) + global_input_keys = self.pm.flat_graph.get_source_keys(output_key) global_input_keys += [self.get_cache_name(output_key, model)] - local_input_keys = list(model._input_keys) + ["cache"] + local_input_keys = list(model.input_keys) + ["cache"] return model, global_input_keys, local_input_keys @@ -308,7 +308,7 @@ def generate_evaluate_gradients( for key in all_ignored_keys if key in self.pm.data and isinstance(self.pm.data[key], Tensor) - and find_intersection_type(self.pm.data[key]._type, float) + and find_intersection_type(self.pm.data[key].type, float) } strict_ignored_keys = all_ignored_keys - weak_ignored_keys @@ -326,17 +326,17 @@ def generate_evaluate_gradients( self.pm._output_keys - possible_loss_keys - ( - self.pm._flat_graph.all_source_keys + self.pm.flat_graph.all_source_keys | { value - for key, value in self.pm._flat_graph.output_dict.items() + for key, value in self.pm.flat_graph.output_dict.items() if key != value } ) ) # Move gradients back for keys in alias_map(pruned or optimized out keys) - for target_key, source_key in self.pm._flat_graph.output_dict.items(): + for target_key, source_key in self.pm.flat_graph.output_dict.items(): if target_key == source_key: continue if target_key not in ignore_grad_keys: @@ -365,15 +365,15 @@ def generate_evaluate_gradients( assign = ast.AugAssign(target=target, op=ast.Add(), value=source) function_body.append(assign) - for output_key in reversed(self.pm._flat_graph.topological_order): + for output_key in reversed(self.pm.flat_graph.topological_order): if output_key in ignore_grad_keys: continue # Iterate over Primitive models in topological order to add their formula. - model = self.pm._flat_graph.get_model(output_key) + model = self.pm.flat_graph.get_model(output_key) - output_key = self.pm._flat_graph.connections[output_key].key - inputs = list(self.pm._flat_graph.get_source_keys(output_key)) + output_key = self.pm.flat_graph.connections[output_key].key + inputs = list(self.pm.flat_graph.get_source_keys(output_key)) # Check if the model is disposable. if model.disposable: @@ -384,14 +384,14 @@ def generate_evaluate_gradients( # Get primitive function inputs order primitive_function = ( - self.backend.primitive_function_dict[model._formula_key] - if model._formula_key in self.backend.primitive_function_dict - else self.backend.registered_primitives[model._formula_key] + self.backend.primitive_function_dict[model.formula_key] + if model.formula_key in self.backend.primitive_function_dict + else self.backend.registered_primitives[model.formula_key] ) local_to_global_dict = { key: value for key, value in zip( - list(model._input_keys) + ["cache"], inputs, strict=False + list(model.input_keys) + ["cache"], inputs, strict=False ) } args, kwargs = prepare_function_args( @@ -405,7 +405,7 @@ def generate_evaluate_gradients( # Get local keys in ordered global_to_local_dict: dict[str, list[str]] = {} for key, value in zip( - list(model._input_keys) + ["cache"], inputs, strict=False + list(model.input_keys) + ["cache"], inputs, strict=False ): global_to_local_dict.setdefault(value, []) global_to_local_dict[value].append(key) @@ -420,7 +420,7 @@ def generate_evaluate_gradients( ] # Reorder global keys wrt primitive evaluate function local keys order - model_local_inputs = list(model._input_keys) + ["cache"] + model_local_inputs = list(model.input_keys) + ["cache"] _inputs = [ inputs[model_local_inputs.index(local_key)] for local_key in primitive_local_inputs @@ -446,7 +446,7 @@ def generate_evaluate_gradients( if grad_fn is None: raise NotImplementedError( - f"Primitive {model._formula_key} does not have vjp " + f"Primitive {model.formula_key} does not have vjp " "implementation!" ) @@ -519,7 +519,7 @@ def generate_evaluate_gradients( for key in sorted(used_keys): if ( key - in self.pm._flat_graph.all_target_keys + in self.pm.flat_graph.all_target_keys | self.pm.data_store.cached_data.keys() ): dict_type = "cache" diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index 3b51549d..97750ffa 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -156,7 +156,7 @@ def exec_generated_code( ) module = importlib.util.module_from_spec(module_spec) # type: ignore module_spec.loader.exec_module(module) # type: ignore - eval_fn = module.evaluate + eval_fn: EvaluateType[DataType] = module.evaluate eval_grad_fn = ( module.evaluate_gradients if hasattr(module, "evaluate_gradients") @@ -224,15 +224,15 @@ def post_process_fns( isinstance(self.pm.backend, ParallelBackend) and self.pm.backend.n_devices > 1 ): - self.pm.backend._register_callable(eval_fn, "eval_fn", jit) + self.pm.backend.register_callable(eval_fn, "eval_fn", jit) if not self.pm.inference: assert grad_fn is not None, "Gradient function is not defined!" assert ( evaluate_all_fn is not None ), "Evaluate all function is not defined!" - self.pm.backend._register_callable(grad_fn, "eval_grad_fn", jit) - self.pm.backend._register_callable(evaluate_all_fn, "eval_all_fn", jit) + self.pm.backend.register_callable(grad_fn, "eval_grad_fn", jit) + self.pm.backend.register_callable(evaluate_all_fn, "eval_all_fn", jit) elif jit and not self.pm.backend.is_manualgrad: eval_fn = self.pm.backend.jit(eval_fn) @@ -296,10 +296,10 @@ def generate_imports(self): return imports def get_primitive_details(self, output_key: str): - model = self.pm._flat_graph.get_model(output_key) + model = self.pm.flat_graph.get_model(output_key) - global_input_keys = self.pm._flat_graph.get_source_keys(output_key) - local_input_keys = list(model._input_keys) + global_input_keys = self.pm.flat_graph.get_source_keys(output_key) + local_input_keys = list(model.input_keys) return model, global_input_keys, local_input_keys @@ -329,20 +329,20 @@ def generate_evaluate(self): return_values: list[ast.expr] = [] used_keys: set[str] = set() - used_keys |= set(self.pm._flat_graph.output_dict.values()) + used_keys |= set(self.pm.flat_graph.output_dict.values()) unused_keys = self.pm.data_store.unused_keys cached_data_keys = self.pm.data_store.cached_data.keys() discarded_keys = self.pm.discarded_keys # TODO: Consider is this necessary? # Iterate over Primitive models in topological order to add their formula. - for output_key in self.pm._flat_graph.topological_order: + for output_key in self.pm.flat_graph.topological_order: # Staticly infered and unused model will not be added if output_key in (cached_data_keys | unused_keys | discarded_keys): continue model, g_input_keys, l_input_keys = self.get_primitive_details(output_key) - formula_key = model._formula_key + formula_key = model.formula_key primitive_function = ( self.pm.backend.primitive_function_dict[formula_key] @@ -370,7 +370,7 @@ def generate_evaluate(self): dict_type = "cache" elif key in self.pm.data_store.runtime_static_keys: dict_type = "data" - elif key not in self.pm._flat_graph.all_target_keys: + elif key not in self.pm.flat_graph.all_target_keys: dict_type = "params" else: continue @@ -381,7 +381,7 @@ def generate_evaluate(self): for output_key in self.pm.output_keys: # TODO: give an api to get outputdict return_values.append( - ast.Name(self.pm._flat_graph.output_dict[output_key], ast.Load()) + ast.Name(self.pm.flat_graph.output_dict[output_key], ast.Load()) ) return_body: list[ast.stmt] = [ @@ -434,7 +434,7 @@ def append_inputs(self, input_body: list[ast.stmt], key: str, dict_type: str): else: val = key - if dict_type != "cache" or (key not in self.pm._flat_graph.all_target_keys): + if dict_type != "cache" or (key not in self.pm.flat_graph.all_target_keys): input_body.append( ast.Assign( targets=[ast.Name(id=val, ctx=ast.Store())], diff --git a/mithril/framework/codegen/torch_gen.py b/mithril/framework/codegen/torch_gen.py index be1f3458..1b326de2 100644 --- a/mithril/framework/codegen/torch_gen.py +++ b/mithril/framework/codegen/torch_gen.py @@ -53,7 +53,7 @@ def call_primitive( if ( formula_key in self.backend.array_creation_funcs - and self.backend._raw_device_mesh is not None + and self.backend.get_raw_device_mesh() is not None ): # Import device mesh and create base device mesh only once! if not self.is_parallel_defined: @@ -69,7 +69,7 @@ def call_primitive( attr="device_meshes", ctx=ast.Load(), ), - slice=ast.Constant(value=self.backend._raw_device_mesh), + slice=ast.Constant(value=self.backend.get_raw_device_mesh()), ctx=ast.Load(), ) base_device_mesh_assgn = ast.Assign( diff --git a/mithril/framework/codegen/utils.py b/mithril/framework/codegen/utils.py index 74d2fdcf..d92977d3 100644 --- a/mithril/framework/codegen/utils.py +++ b/mithril/framework/codegen/utils.py @@ -28,7 +28,7 @@ def partial_array_creation_func( # We don't need device in manulgrad(Numpy) if not backend.is_manualgrad: kwargs.append( - ast.keyword(arg="device", value=ast.Constant(value=backend._device)) + ast.keyword(arg="device", value=ast.Constant(value=backend.get_device())) ) partial_fn_call = ast.Call( diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 1cf79e3d..aa7c48d2 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -51,7 +51,7 @@ ) __all__ = [ - "_get_shapes", + "get_shapes", "NOT_GIVEN", "TBD", "IOKey", @@ -593,7 +593,7 @@ def __ior__(self, other: Updates) -> Updates: return self -def _get_shapes( +def get_shapes( data_dict: dict[str, Tensor | Scalar], uniadic_keys: dict[UniadicRecord, str] | None = None, varadic_keys: dict[Variadic, str] | None = None, @@ -625,8 +625,18 @@ def _get_shapes( class BaseData(Generic[T]): + @overload + def __init__( + self: BaseData[_TensorTypes], type: T, is_tensor: Literal[True] = True + ) -> None: ... + + @overload + def __init__( + self: BaseData[ScalarType], type: T, is_tensor: Literal[False] = False + ) -> None: ... + def __init__(self, type: T, is_tensor: bool = False) -> None: - self._type: T = type + self.type: T = type self.shape: ShapeNode | None = None self.shape_constraints: set[Constraint] = set() self.type_constraints: set[Constraint] = set() @@ -636,7 +646,7 @@ def __init__(self, type: T, is_tensor: bool = False) -> None: @property def is_non_diff(self) -> bool: - return not self._differentiable + return not self.differentiable @property def is_valued(self) -> bool: @@ -646,7 +656,7 @@ def is_valued(self) -> bool: def all_constraints(self) -> set[Constraint]: return self.shape_constraints | self.type_constraints - def finalize_match(self, other: BaseData): + def finalize_match(self, other: BaseData[T]): if (typ_1 := type(other)) != (typ_2 := type(self)): raise TypeError( f"Replacement can be done for only same types. Got {typ_1} and {typ_2}" @@ -662,22 +672,22 @@ def set_type(self, typ: T) -> Updates: updates = Updates() if not self._types_equal(typ): updates.add(self, UpdateType.TYPE) # type: ignore - new_type = find_intersection_type(typ, self._type) + new_type = find_intersection_type(typ, self.type) if not new_type: raise TypeError( - f"Acceptable types are {self._type}, but {typ} type value " + f"Acceptable types are {self.type}, but {typ} type value " "is provided!" ) - self._type = new_type # type: ignore + self.type = new_type # type: ignore return updates def _types_equal(self, other_type: T) -> bool: self_type: T - if isinstance(self._type, NestedListType): - self_type = self._type.base_type + if isinstance(self.type, NestedListType): + self_type = self.type.base_type else: - self_type = self._type + self_type = self.type if isinstance(other_type, NestedListType): other_type = other_type.base_type @@ -699,18 +709,17 @@ def remove_constraint(self, constraint: Constraint): elif constraint.type == UpdateType.TYPE: self.type_constraints.discard(constraint) - def match(self, other: BaseData) -> Updates: - self._differentiable: bool + def match(self, other: BaseData[T]) -> Updates: + self.differentiable: bool updates = Updates() if self != other: updates = Updates() - updates |= self.set_type(other._type) - updates |= other.set_type(self._type) + updates |= self.set_type(other.type) + updates |= other.set_type(self.type) if other.is_tensor: - assert self.is_tensor updates |= self.match_shapes(other) - is_diff = self._differentiable | other._differentiable - self._differentiable = other._differentiable = is_diff + is_diff = self.differentiable | other.differentiable + self.differentiable = other.differentiable = is_diff if self.is_valued or other.is_valued: valued, non_valued = (self, other) if self.is_valued else (other, self) @@ -743,7 +752,7 @@ def set_value(self, value: AllValueType) -> Updates: # type: ignore[override] shape = list_shape(value) # type: ignore assert isinstance(self.shape, ShapeNode) updates |= self.shape.set_values(shape) - self._differentiable = False + self.differentiable = False return updates def find_type(self, value: AllValueType) -> type: @@ -761,7 +770,7 @@ def make_physical(self, backend: Backend[DataType], memo: dict[int, Any]): physical_data.value = epsilon_table[backend.precision][self.value] return physical_data - def match_shapes(self, other: BaseData[_TensorTypes]): + def match_shapes(self, other: BaseData[T]): assert isinstance(other.shape, ShapeNode) assert isinstance(self.shape, ShapeNode) @@ -793,7 +802,7 @@ def __init__( if not isinstance(value, ToBeDetermined): self.set_type(find_dominant_type(value)) self.value: TensorValueType = value - self._differentiable: bool = self.value is TBD + self.differentiable: bool = self.value is TBD self.shape: ShapeNode = shape # Update referees field of ShapeNode. self.shape.referees.add(self) @@ -812,7 +821,7 @@ def __init__( if not isinstance(value, ToBeDetermined | str): self.set_type(self.find_type(value)) self.value: ScalarValueType = value - self._differentiable = False + self.differentiable = False # TODO: Convert MyTensor to Tensor when Tensor and Scalar is removed @@ -826,12 +835,12 @@ def __init__(self, value: TensorValueType): # Check if a type is a specialization of MyTensor -def is_mytensor_type(type_obj) -> bool: +def is_mytensor_type(type_obj: Any) -> bool: return get_origin(type_obj) is MyTensor # Check if a type is a specialization of MyTensor -def get_mytensor_subtype(type_obj: MyTensor) -> type: +def get_mytensor_subtype(type_obj: Any) -> type: return get_args(type_obj)[0] @@ -1004,7 +1013,7 @@ def len(self): def shape(self): return ExtendTemplate(connections=[self], model="shape") - def reshape(self, shape: tuple[int, ...] | TemplateBase): + def reshape(self, shape: tuple[int | TemplateBase, ...] | TemplateBase): return ExtendTemplate(connections=[self, shape], model="reshape") def size(self, dim: int | tuple[int, ...] | TemplateBase | None = None): @@ -1141,6 +1150,26 @@ def __init__( conn = item.data if isinstance(item, Connection) else item self._connections.add(conn) + @property + def name(self): + return self._name + + @property + def value(self): + return self._value + + @property + def type(self): + return self._type + + @property + def expose(self): + return self._expose + + @property + def interval(self): + return self._interval + def __hash__(self) -> int: return hash(id(self)) @@ -1195,7 +1224,7 @@ def __eq__(self, other: object) -> bool: def set_differentiable(self, differentiable: bool = True) -> None: if isinstance(self.metadata.data, Tensor): - self.metadata.data._differentiable = differentiable + self.metadata.data.differentiable = differentiable else: raise ValueError("Scalar data can not be set as differentiable.") @@ -1205,7 +1234,7 @@ def set_differentiable(self, differentiable: bool = True) -> None: | int | float | list[int | float] - | tuple[slice | int | None | EllipsisType, ...] + | tuple[slice | int | None | EllipsisType | TemplateBase, ...] | None ) @@ -1307,15 +1336,13 @@ def add( self.connections_dict.setdefault(metadata, set()).add(self) def set_connection_type( - self, connection: ConnectionData, con_type: KeyType + self, + connection: ConnectionData, + con_type: KeyType, + safe: bool = True, ) -> None: - if con_type == KeyType.OUTPUT and connection.is_key_autogenerated: + if safe and con_type == KeyType.OUTPUT and connection.is_key_autogenerated: raise KeyError("Connection without a name cannot be set as output") - self._set_connection_type(connection, con_type) - - def _set_connection_type( - self, connection: ConnectionData, con_type: KeyType - ) -> None: key = connection.key for _type in KeyType: if _type == con_type: @@ -1328,7 +1355,7 @@ def remove_connection(self, connection: ConnectionData) -> None: self._connection_dict[_type].pop(connection.key, None) def get_data(self, key: str) -> Scalar | Tensor: - return self._get_metadata(key).data + return self.get_metadata(key).data def get_non_diff_keys(self): return {key for key, conn in self.all.items() if conn.metadata.data.is_non_diff} @@ -1354,16 +1381,16 @@ def get_con_by_metadata(self, key: IOHyperEdge) -> ConnectionData | None: def get_cons_by_metadata(self, key: IOHyperEdge): return self.metadata_dict.get(key) - def _get_metadata(self, key: str) -> IOHyperEdge: + def get_metadata(self, key: str) -> IOHyperEdge: if (con := self.get_connection(key)) is not None: return con.metadata raise KeyError(f"Key '{key}' is not found in connections.") def get_key_origin(self, key: str) -> str | None: - return self._get_metadata(key).key_origin + return self.get_metadata(key).key_origin def get_shape_node(self, key: str) -> ShapeNode: - data = self._get_metadata(key).data + data = self.get_metadata(key).data if not isinstance(data, Tensor): raise ValueError("'Shape cannot be set for scalar type values'") return data.shape @@ -1376,7 +1403,7 @@ def extract_metadata(self, key: str | Connection) -> IOHyperEdge: # Extract the key from the Connection object. metadata = key.metadata else: - metadata = self._get_metadata(key) + metadata = self.get_metadata(key) return metadata @@ -2223,7 +2250,7 @@ def merge(self, other: ShapeNode) -> Updates: for repr2 in other.reprs: for repr1 in self.reprs: # Match all reprs of other with self.reprs. - updates |= repr1._match(repr2) + updates |= repr1.match(repr2) if ( len(repr1.prefix) == len(repr2.prefix) and len(repr1.suffix) == len(repr2.suffix) @@ -2256,7 +2283,7 @@ def combine(self): # Iterate over all repr pairs and remove matching reprs. for repr, other_repr in combinations(self.reprs, 2): if repr not in same_reprs and other_repr not in same_reprs: - updates |= repr._match(other_repr) + updates |= repr.match(other_repr) if ( len(repr.prefix) == len(other_repr.prefix) and len(repr.suffix) == len(other_repr.suffix) @@ -2495,7 +2522,7 @@ def __len__(self) -> int: return len(self.prefix) + len(self.suffix) @staticmethod - def _update_uniadics( + def update_uniadics( outer_list: list[Uniadic], inner_list: list[Uniadic] ) -> Updates: updates = Updates() @@ -2568,8 +2595,8 @@ def inner_match( "Determined shape representations should have same length." ) # Match all parallel uniadics. - updates |= self._update_uniadics(self.prefix, prefix) - updates |= self._update_uniadics( + updates |= self.update_uniadics(self.prefix, prefix) + updates |= self.update_uniadics( self.reverse, suffix[::-1] if root is not None else prefix[::-1] ) if bool(root) ^ bool(self.root): @@ -2647,15 +2674,15 @@ def inner_match( ) return updates - def _match(self, other: ShapeRepr) -> Updates: + def match(self, other: ShapeRepr) -> Updates: return self.inner_match(other.prefix, other.root, other.suffix) def set_values(self, values: Sequence[int | None]) -> Updates: updates = Updates() if self.root is not None: uniadics = [Uniadic(value) for value in values] - updates |= self._update_uniadics(self.prefix, uniadics) - updates |= self._update_uniadics(self.reverse, uniadics[::-1]) + updates |= self.update_uniadics(self.prefix, uniadics) + updates |= self.update_uniadics(self.reverse, uniadics[::-1]) updates |= self.remove_variadic(uniadics) # updates -= set(uniadics) else: @@ -3426,17 +3453,17 @@ def get_summary_types( for model, model_name in name_mappings.items(): in_dict, out_dict = type_info.setdefault(model_name, ({}, {})) for key in model.conns.all: - key_mappings = model._generate_keys(include_outputs=True) + key_mappings = model.generate_keys(include_outputs=True) in_key = key_mappings.get(key, key) data = model.conns.get_data(key) pm_data = data_memo.get(id(data), data) - data_type = pm_data._type + data_type = pm_data.type if not hasattr(data_type, "__args__"): str_type = data_type.__name__ else: - sorted_type = sort_type(pm_data._type) + sorted_type = sort_type(pm_data.type) str_type = str(sorted_type) - if key in model._input_keys: + if key in model.input_keys: in_dict[in_key] = str_type else: out_dict[in_key] = str_type @@ -3452,7 +3479,7 @@ def is_type_adjustment_required(data: dict[str, Tensor | Scalar], inputs: list[s if not isinstance(left, Tensor) or not isinstance(right, Tensor): return False - rule1 = issubclass(float, left._type) and issubclass(int, right._type) - rule2 = issubclass(float, right._type) and issubclass(int, left._type) + rule1 = issubclass(float, left.type) and issubclass(int, right.type) + rule2 = issubclass(float, right.type) and issubclass(int, left.type) return rule1 | rule2 diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index dceff75e..2345b0e8 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -161,7 +161,7 @@ def general_tensor_type_constraint(*args: Tensor): union_types: set[tuple[Tensor, UnionType]] = set() # Set all different types and also Union types in input args. for arg in inputs: - typ = arg._type + typ = arg.type arg_types.add(typ) if isinstance(typ, UnionType): union_types.add((arg, typ)) @@ -174,29 +174,29 @@ def general_tensor_type_constraint(*args: Tensor): raise TypeError(f"Possible Unsupported type(s) ({unsupported}) detected!") # Try reverse type inference first. - if not isinstance(output._type, UnionType): + if not isinstance(output.type, UnionType): # Means output has a definite type (int, float or bool). - out_exists = output._type in arg_types + out_exists = output.type in arg_types related_unions = { - pair for pair in union_types if output._type in pair[1].__args__ + pair for pair in union_types if output.type in pair[1].__args__ } if not (out_exists or related_unions): # At least one of arg_types or UnionTypes must contain # output type. raise TypeError( - f"None of arguments consist of type {output._type} which is the " + f"None of arguments consist of type {output.type} which is the " "exact output type!" ) elif not out_exists and len(related_unions) == 1: # If only one of them contains output type, enforce this union # type to be same as output type. arg = related_unions.pop()[0] - updates |= arg.set_type(output._type) + updates |= arg.set_type(output.type) status = True # Update Union type arguments. for pair in related_unions: arg, arg_type = pair - new_type = _reduce_union_type(output._type, arg_type) + new_type = _reduce_union_type(output.type, arg_type) if new_type is not None: uni_type = create_union_type(*new_type) assert not isinstance(uni_type, GenericAlias) @@ -205,10 +205,10 @@ def general_tensor_type_constraint(*args: Tensor): # If any one of inputs became same type as output, set # status True. for pair in related_unions: - if pair[0]._type == output._type: + if pair[0].type == output.type: status = True break - elif output._type == int | bool: + elif output.type == int | bool: if float in arg_types: raise TypeError( "One of arguments value is float which is not possible when output " @@ -218,7 +218,7 @@ def general_tensor_type_constraint(*args: Tensor): # We can eliminate any float possibility from Union type args. for pair in union_types: arg, arg_type = pair - new_type = _reduce_union_type(output._type, arg_type) + new_type = _reduce_union_type(output.type, arg_type) if new_type is not None: uni_type = create_union_type(*new_type) assert not isinstance(uni_type, GenericAlias) @@ -246,12 +246,12 @@ def general_tensor_type_constraint(*args: Tensor): if not union_types: out_type = bool elif all_possible_types.issuperset({float, int}): - if output._type != float | int | bool: + if output.type != float | int | bool: out_type = float | int | bool elif all_possible_types.issuperset({float}): - if output._type != float | bool: + if output.type != float | bool: out_type = float | bool - elif all_possible_types.issuperset({int}) and output._type != int | bool: + elif all_possible_types.issuperset({int}) and output.type != int | bool: out_type = int | bool elif int | float in arg_types: out_type = int | float @@ -275,22 +275,22 @@ def floor_divide_type_constraint( # constrain its type to float | int. updates |= output.set_type(int | float) # Try reverse type inference first. - if output._type is int: + if output.type is int: # Only possible when numerator and denominator are integers or booleans. updates |= numerator.set_type(int | bool) updates |= denominator.set_type(int | bool) status = True - elif output._type is float: + elif output.type is float: # At least one of inputs is float. if ( - isinstance(numerator._type, UnionType) - and float not in numerator._type.__args__ + isinstance(numerator.type, UnionType) + and float not in numerator.type.__args__ ): updates |= denominator.set_type(float) status = True elif ( - isinstance(denominator._type, UnionType) - and float not in denominator._type.__args__ + isinstance(denominator.type, UnionType) + and float not in denominator.type.__args__ ): updates |= numerator.set_type(float) status = True @@ -301,8 +301,8 @@ def scalar_slice_type_constraint( output: Scalar, input: Scalar, start: Scalar, stop: Scalar, step: Scalar ): updates = Updates() - output_type = output._type - input_type = input._type + output_type = output.type + input_type = input.type assert ( isinstance(start.value, ToBeDetermined) @@ -380,7 +380,7 @@ def scalar_slice_type_constraint( else: raise TypeError("Inferred types does not match in slice constraints!") - status = not is_union(output._type) + status = not is_union(output.type) return status, updates @@ -510,10 +510,10 @@ def scalar_item_reduce_input_type( def scalar_item_type_constraint(output: Scalar, input: Scalar, index: Scalar): updates = Updates() - assert not isinstance(input._type, NestedListType) - assert not isinstance(output._type, NestedListType) - input_type = input._type - output_type = output._type + assert not isinstance(input.type, NestedListType) + assert not isinstance(output.type, NestedListType) + input_type = input.type + output_type = output.type index_value = index.value assert isinstance(index_value, ToBeDetermined) or type(index_value) is int @@ -537,7 +537,7 @@ def scalar_item_type_constraint(output: Scalar, input: Scalar, index: Scalar): # extract all possibilites and put it in to a list # TODO: This part should take NestedListType into account. - args = input._type.__args__ if isinstance(input._type, UnionType) else [input._type] + args = input.type.__args__ if isinstance(input.type, UnionType) else [input.type] # Do the forward inference in all types in args, then make Union types = [ @@ -547,7 +547,7 @@ def scalar_item_type_constraint(output: Scalar, input: Scalar, index: Scalar): updates |= output.set_type(inferred_out_type) - status = not is_union(output._type) + status = not is_union(output.type) return status, updates @@ -586,36 +586,36 @@ def slice_constraints(output: Scalar, start: Scalar, stop: Scalar, step: Scalar) def tensor_to_list_type_constraint(output: Scalar, input: Tensor): - status = not is_union(output._type) + status = not is_union(output.type) updates = Updates() assert input._temp_shape is not None in_shape: ShapeRepr = input._temp_shape assert ( - output._type is list - or output._type is float - or output._type is int - or output._type is bool - or isinstance(output._type, NestedListType | UnionType) - or (isinstance(output._type, GenericAlias) and output._type.__origin__ is list) + output.type is list + or output.type is float + or output.type is int + or output.type is bool + or isinstance(output.type, NestedListType | UnionType) + or (isinstance(output.type, GenericAlias) and output.type.__origin__ is list) ) # If input type is UnionType, try to constrain it using output type - if get_origin(input._type) == UnionType and ( - out_types := find_list_base_type(output._type) # type: ignore # (MyPy bug) + if get_origin(input.type) == UnionType and ( + out_types := find_list_base_type(output.type) # type: ignore # (MyPy bug) ): possible_input_types = find_intersection_type( - input._type, create_union_type(*out_types) + input.type, create_union_type(*out_types) ) if not possible_input_types: raise TypeError( - f"Input type {input._type} is not compatible with output type " - f"{output._type}!" + f"Input type {input.type} is not compatible with output type " + f"{output.type}!" ) assert not isinstance(possible_input_types, NestedListType) updates |= input.set_type(possible_input_types) # Create the base same as input type - base = input._type + base = input.type if in_shape.root is None: for _ in range(len(in_shape.prefix + in_shape.suffix)): # recursively cover list with base equal to number of all determined @@ -628,9 +628,7 @@ def tensor_to_list_type_constraint(output: Scalar, input: Tensor): updates |= output.set_type(base) if in_shape.root is not None: - status = not ( - is_union(output._type) or isinstance(output._type, NestedListType) - ) + status = not (is_union(output.type) or isinstance(output.type, NestedListType)) else: status = True @@ -639,7 +637,7 @@ def tensor_to_list_type_constraint(output: Scalar, input: Tensor): def reduce_type_constraint(output: Tensor, input: Tensor): updates = Updates() - input_type = input._type + input_type = input.type possible_output_types: list[type[int] | type[float] | type[bool]] = [] @@ -657,14 +655,14 @@ def reduce_type_constraint(output: Tensor, input: Tensor): updates |= output.set_type(union_output_types) ### Reverse Inference ### - if output._type is float: + if output.type is float: # if output type is float, it is guaranteed that input will be float updates |= input.set_type(float) - elif output._type is int: + elif output.type is int: # if output type is int, input should either be int or bool updates |= input.set_type(bool | int) - status = not isinstance(output._type, UnionType) + status = not isinstance(output.type, UnionType) return status, updates @@ -1413,8 +1411,8 @@ def reduce_constraints( axis_val = (axis_val,) elif axis_val is None: if keepdim_val is False: - updates |= input_shape._update_uniadics(input_shape.prefix, []) - updates |= output_shape._update_uniadics(output_shape.reverse, []) + updates |= input_shape.update_uniadics(input_shape.prefix, []) + updates |= output_shape.update_uniadics(output_shape.reverse, []) if output_shape.root is not None: updates |= output_shape.remove_variadic([]) elif not isinstance(axis_val, tuple): @@ -1581,10 +1579,10 @@ def reduce_constraints( filtered_var_replacement: list[Uniadic] = list( filter(None, var_replacement) ) - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics( output_shape.prefix, filtered_var_replacement ) - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics( output_shape.reverse, filtered_var_replacement[::-1] ) updates |= output_shape.remove_variadic(filtered_var_replacement) @@ -1636,8 +1634,8 @@ def reduce_constraints( updates |= next(out_iter).match(replacement) else: input_uniadics.append(next(out_iter)) - updates |= input_shape._update_uniadics(input_shape.prefix, input_uniadics) - updates |= input_shape._update_uniadics( + updates |= input_shape.update_uniadics(input_shape.prefix, input_uniadics) + updates |= input_shape.update_uniadics( input_shape.reverse, input_uniadics[::-1] ) updates |= input_shape.remove_variadic(input_uniadics) @@ -1846,10 +1844,10 @@ def reverse_constraints( if axes_val is None: if output_shape.root is None: # TODO Maybe we should embed uniadic updates in remove_variadic - updates |= input_shape._update_uniadics( + updates |= input_shape.update_uniadics( input_shape.prefix, output_shape.reverse ) - updates |= input_shape._update_uniadics( + updates |= input_shape.update_uniadics( input_shape.reverse, output_shape.prefix ) if input_shape.root is not None: @@ -1858,10 +1856,10 @@ def reverse_constraints( raise ValueError("Shape mismatch in Transpose model") status = True if input_shape.root is None: - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics( output_shape.prefix, input_shape.reverse ) - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics( output_shape.reverse, input_shape.prefix ) if output_shape.root is not None: @@ -1877,11 +1875,11 @@ def reverse_constraints( in_unis = [Uniadic() for _ in range(len(a_val))] out_unis = [in_unis[axis] for axis in a_val] - updates |= input_shape._update_uniadics(input_shape.prefix, in_unis) - updates |= input_shape._update_uniadics(input_shape.reverse, in_unis[::-1]) + updates |= input_shape.update_uniadics(input_shape.prefix, in_unis) + updates |= input_shape.update_uniadics(input_shape.reverse, in_unis[::-1]) - updates |= output_shape._update_uniadics(output_shape.prefix, out_unis) - updates |= output_shape._update_uniadics(output_shape.reverse, out_unis[::-1]) + updates |= output_shape.update_uniadics(output_shape.prefix, out_unis) + updates |= output_shape.update_uniadics(output_shape.reverse, out_unis[::-1]) if input_shape.root is not None: updates |= input_shape.remove_variadic(in_unis) @@ -2497,10 +2495,10 @@ def arange_constraints( elif (min_dims := len(output_shape)) <= 1: if val > 0: out_uniadic = [Uniadic()] - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics( output_shape.prefix, out_uniadic ) - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics( output_shape.reverse, out_uniadic ) updates |= output_shape.remove_variadic(out_uniadic) @@ -2545,8 +2543,8 @@ def randn_constraints(output: Tensor, shape: Scalar) -> ConstrainResultType: f"must have exactly {len(shape_val)} dim(s)." ) out_uniadics = [Uniadic(dim) for dim in shape_val] - updates |= output_shape._update_uniadics(output_shape.prefix, out_uniadics) - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics(output_shape.prefix, out_uniadics) + updates |= output_shape.update_uniadics( output_shape.reverse, out_uniadics[::-1] ) updates |= output_shape.remove_variadic(out_uniadics) @@ -2593,8 +2591,8 @@ def broadcast_to_constraints( f"must have exactly {len(shape_val)} dim(s)." ) out_uniadics = [Uniadic(dim) for dim in shape_val] - updates |= output_shape._update_uniadics(output_shape.prefix, out_uniadics) - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics(output_shape.prefix, out_uniadics) + updates |= output_shape.update_uniadics( output_shape.reverse, out_uniadics[::-1] ) updates |= output_shape.remove_variadic(out_uniadics) @@ -2684,8 +2682,8 @@ def reshape_constraints( out_uniadics = [ Uniadic(val) if val != -1 else Uniadic() for val in shape_val ] - updates |= output_shape._update_uniadics(output_shape.prefix, out_uniadics) - updates |= output_shape._update_uniadics( + updates |= output_shape.update_uniadics(output_shape.prefix, out_uniadics) + updates |= output_shape.update_uniadics( output_shape.reverse, out_uniadics[::-1] ) updates |= output_shape.remove_variadic(out_uniadics) @@ -2757,8 +2755,8 @@ def reshape_constraints( values: list[int | None] | tuple[int | None, ...] = [ uni.value for uni in output_shape.prefix ] - assert isinstance(shape._type, GenericAlias) - if shape._type.__origin__ is tuple: + assert isinstance(shape.type, GenericAlias) + if shape.type.__origin__ is tuple: values = tuple(values) # TODO: This update assumes no -1 is given in shapes. However, # situations may occur where shape is given with -1. @@ -3151,7 +3149,7 @@ def swap_axes_constraints( updates |= other[axis2_val].match(non_variadic[axis1_val]) else: - updates |= other._match(non_variadic) + updates |= other.match(non_variadic) other[axis1_val], other[axis2_val] = other[axis2_val], other[axis1_val] status = True @@ -3200,9 +3198,9 @@ def to_tensor_constraints(output: Tensor, input: Scalar) -> ConstrainResultType: updates |= output.set_type(typ) updates.add(output, update_type=UpdateType.TYPE) elif isinstance(input_val, float | int): - assert isinstance(input._type, type(int) | type(float)) + assert isinstance(input.type, type(int) | type(float)) shape = [] - updates |= output.set_type(input._type) + updates |= output.set_type(input.type) updates.add(output, update_type=UpdateType.TYPE) if output_shape.root is None: if len(shape) != len(output_shape.prefix): @@ -3719,15 +3717,15 @@ def cross_entropy_constraint( if categorical_value is not TBD: if not categorical_value: - updates |= target_shape._match(input_shape) + updates |= target_shape.match(input_shape) else: N = Uniadic() C = Uniadic() var = Variadic() in_repr = ShapeRepr([N, C], var) target_repr = ShapeRepr([N], var) - updates = input_shape._match(in_repr) - updates = target_shape._match(target_repr) + updates = input_shape.match(in_repr) + updates = target_shape.match(target_repr) status = True return status, updates diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 44a8e44e..5ffee533 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -52,9 +52,8 @@ Updates, UpdateType, Variadic, - _get_shapes, - _ShapesType, create_shape_repr, + get_shapes, ) from ..constraints import post_process_map, type_constraints @@ -63,17 +62,17 @@ @dataclass class ExtendInfo: - _model: BaseModel - _connections: dict[str, ConnectionType] + model: BaseModel + connections: dict[str, ConnectionType] def __post_init__(self): - external_keys = set(self._model.external_keys) - if self._model.canonical_input is not NOT_AVAILABLE: - external_keys.add(self._model.canonical_input.key) - if self._model.canonical_output is not NOT_AVAILABLE: - external_keys.add(self._model.canonical_output.key) + external_keys = set(self.model.external_keys) + if self.model.canonical_input is not NOT_AVAILABLE: + external_keys.add(self.model.canonical_input.key) + if self.model.canonical_output is not NOT_AVAILABLE: + external_keys.add(self.model.canonical_output.key) - for key in self._connections: + for key in self.connections: if key not in external_keys: raise KeyError(f"Key '{key}' is not a valid key for the model!") @@ -105,7 +104,7 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: case str(): kwargs[key] = IOKey(con, value=val, expose=False) case IOKey(): - if con._value is not TBD and con._value != val: + if con.value is not TBD and con.value != val: raise ValueError( f"Given IOKey for local key: '{key}' is not valid!" ) @@ -115,7 +114,7 @@ def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: for item in con._connections ] kwargs[key] = IOKey( - name=con._name, + name=con.name, value=val, shape=con._shape, type=con._type, @@ -176,7 +175,7 @@ def jittable(self) -> bool: return self._jittable @property - def shapes(self) -> _ShapesType: + def shapes(self): return self.get_shapes() @property @@ -184,7 +183,7 @@ def external_keys(self): return self.conns.io_keys @property - def _input_keys(self): + def input_keys(self): return self.conns.input_keys @property @@ -212,7 +211,7 @@ def _get_outermost_parent(self): model = model.parent return model - def _generate_keys( + def generate_keys( self, symbolic: bool = True, include_internals: bool = True, @@ -428,8 +427,8 @@ def get_shapes( var_keys: dict[Variadic, str] | None = None, symbolic: bool = True, verbose: bool = False, - ) -> _ShapesType: - return _get_shapes( + ) -> Mapping[str, ShapeTemplateType | list[ShapeTemplateType] | None]: + return get_shapes( data_dict={ key: value.metadata.data for key, value in self.conns.all.items() }, @@ -437,7 +436,7 @@ def get_shapes( varadic_keys=var_keys, symbolic=symbolic, verbose=verbose, - key_mappings=self._generate_keys(include_outputs=True), + key_mappings=self.generate_keys(include_outputs=True), ) def _set_constraint( @@ -505,7 +504,7 @@ def set_canonical_input(self, given_conn: str | Connection): conn = self.conns.get_con_by_metadata(conn.metadata) - if conn not in self.dependency_map._local_input_dependency_map: + if conn not in self.dependency_map.local_input_dependency_map: raise ValueError( "To set a connection as canonical input, connection must be an " "input connection!" @@ -523,7 +522,7 @@ def set_canonical_output(self, given_conn: str | Connection): conn = self.conns.get_con_by_metadata(conn.metadata) - if conn not in self.dependency_map._local_output_dependency_map: + if conn not in self.dependency_map.local_output_dependency_map: raise ValueError( "To set a connection as canonical output, connection must be an " "output connection!" @@ -564,7 +563,7 @@ def _match_hyper_edges(self, left: IOHyperEdge, right: IOHyperEdge) -> Updates: return updates def get_models_in_topological_order(self): - dependency_map = self.dependency_map._local_output_dependency_map + dependency_map = self.dependency_map.local_output_dependency_map graph = { info[0]: OrderedSet( [dependency_map[spec][0] for spec in info[1] if spec in dependency_map] @@ -618,12 +617,12 @@ def __init__(self, connections: Connections) -> None: ] = {} # Stores releation between local input keys to dependent local # output connections - self._local_input_dependency_map: dict[ + self.local_input_dependency_map: dict[ ConnectionData, list[tuple[BaseModel, OrderedSet[ConnectionData]]] ] = {} # Stores releation between local output keys to dependent local # input connections - self._local_output_dependency_map: dict[ + self.local_output_dependency_map: dict[ ConnectionData, tuple[BaseModel, OrderedSet[ConnectionData]] ] = {} @@ -641,7 +640,7 @@ def add_model_dag(self, model: BaseModel, model_dag: dict[str, ConnectionData]): if model_dag.get(conn.key) is not None ] ) - self._local_input_dependency_map.setdefault(conn, []).append( + self.local_input_dependency_map.setdefault(conn, []).append( (model, specs) ) updated_conns.add(conn) @@ -655,10 +654,10 @@ def add_model_dag(self, model: BaseModel, model_dag: dict[str, ConnectionData]): if model_dag.get(conn.key) is not None ] ) - self._local_output_dependency_map[conn] = (model, specs) + self.local_output_dependency_map[conn] = (model, specs) updated_conns.add(conn) - self._cache_internal_references(conn, specs) + self.cache_internal_references(conn, specs) if self.look_for_cyclic_connection(conn, specs): raise Exception( @@ -666,10 +665,10 @@ def add_model_dag(self, model: BaseModel, model_dag: dict[str, ConnectionData]): f"{[spec.key for spec in specs]} key(s)!" ) - self._update_globals(updated_conns) + self.update_globals(updated_conns) # Caches extended connections to avoid traverse - def _cache_internal_references( + def cache_internal_references( self, output_conn: ConnectionData, dependent_conns: OrderedSet[ConnectionData] ): # Be sure all input and output keys has cache entry @@ -785,7 +784,7 @@ def get_dependent_output_conns(self, key: str) -> OrderedSet[ConnectionData]: def update_all_keys(self): # This method is used in freeze, because in freeze dependencies changed # without updating dependency map. - self._update_globals( + self.update_globals( OrderedSet(self.conns.input_connections) | OrderedSet(self.conns.output_connections) ) @@ -821,7 +820,7 @@ def _get_from_output_cache(self, conn: ConnectionData): return dependent_conns # Update global dependency maps wrt given connections - def _update_globals(self, updated_conns: OrderedSet[ConnectionData]): + def update_globals(self, updated_conns: OrderedSet[ConnectionData]): for input_conn in self.conns.input_connections: self._global_input_dependency_map.setdefault(input_conn, OrderedSet()) @@ -871,7 +870,7 @@ def get_input_key_dependency(self, key: str): specs = OrderedSet( [ key - for item in self._local_input_dependency_map[given_conn] + for item in self.local_input_dependency_map[given_conn] for key in item[1] if key in self.conns.output_keys ] @@ -881,7 +880,7 @@ def get_input_key_dependency(self, key: str): key_stack = OrderedSet( [ spec - for item in self._local_input_dependency_map[given_conn] + for item in self.local_input_dependency_map[given_conn] for spec in item[1] if spec not in specs ] @@ -896,11 +895,11 @@ def get_input_key_dependency(self, key: str): OrderedSet( [ spec - for item in self._local_input_dependency_map[conn_data] + for item in self.local_input_dependency_map[conn_data] for spec in item[1] ] ) - if conn_data in self._local_input_dependency_map + if conn_data in self.local_input_dependency_map else OrderedSet() ) return specs @@ -914,7 +913,7 @@ def get_output_key_dependency(self, key: str): specs = OrderedSet( [ key - for key in self._local_output_dependency_map[given_conn][1] + for key in self.local_output_dependency_map[given_conn][1] if key in self.conns.input_keys ] ) @@ -922,7 +921,7 @@ def get_output_key_dependency(self, key: str): key_stack = OrderedSet( [ spec - for spec in self._local_output_dependency_map[given_conn][1] + for spec in self.local_output_dependency_map[given_conn][1] if spec not in specs ] ) @@ -933,8 +932,8 @@ def get_output_key_dependency(self, key: str): # key_stack.update(self.dependency_map.get(key.key, OrderedSet())) # TODO: add test checking the while key_stack |= ( - self._local_output_dependency_map[conn_data][1] - if conn_data in self._local_output_dependency_map + self.local_output_dependency_map[conn_data][1] + if conn_data in self.local_output_dependency_map else OrderedSet() ) return specs @@ -947,9 +946,9 @@ def look_for_cyclic_connection( return True else: for conn in conns: - if conn in self._local_output_dependency_map: + if conn in self.local_output_dependency_map: return self.look_for_cyclic_connection( - target_conn, self._local_output_dependency_map[conn][1] + target_conn, self.local_output_dependency_map[conn][1] ) return False diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index 4235719a..c12027ec 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -160,7 +160,7 @@ def __init__( super().__init__(formula_key="to_tuple", name=name, **key_definitions) self._set_constraint( fn=to_tuple_constraints, - keys=[PrimitiveModel.output_key] + [key for key in self._input_keys], + keys=[PrimitiveModel.output_key] + [key for key in self.input_keys], ) @@ -802,7 +802,7 @@ def __init__(self, n: int, name: str | None = None, **kwargs) -> None: self._set_constraint( fn=to_list_constraints, - keys=[PrimitiveModel.output_key] + [key for key in self._input_keys], + keys=[PrimitiveModel.output_key] + [key for key in self.input_keys], ) diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index aaf8f79b..22b91f80 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -165,7 +165,7 @@ def __init__( ) -> None: self.dag: dict[BaseModel, dict[str, ConnectionData]] = {} self.inter_key_count: int = 0 - self._formula_key: str | None = None + self.formula_key: str | None = None super().__init__(name=name, enforce_jit=enforce_jit) @@ -222,15 +222,15 @@ def set_outputs(self, *args: str | Connection, **kwargs: str | Connection) -> No if new_name is None: # Non-named connections. # Set connection as output and update dependency map. self.conns.set_connection_type(conn_data, KeyType.OUTPUT) - self.dependency_map._update_globals(OrderedSet({conn_data})) + self.dependency_map.update_globals(OrderedSet({conn_data})) else: # Named connections. # Create new output connection with given key name. data: Tensor | Scalar = ( - Scalar(metadata.data._type) + Scalar(metadata.data.type) if isinstance(metadata.data, Scalar) - else Tensor(metadata.data.shape, metadata.data._type) + else Tensor(metadata.data.shape, metadata.data.type) ) new_conn = self.create_connection(IOHyperEdge(data), new_name) @@ -242,7 +242,7 @@ def set_outputs(self, *args: str | Connection, **kwargs: str | Connection) -> No self.merge_connections(new_conn, conn_data) def _set_formula_key(self, formula_key: str): - self._formula_key = formula_key + self.formula_key = formula_key def _check_multi_write( self, @@ -251,12 +251,12 @@ def _check_multi_write( connection: ConnectionData, ) -> None: conn_is_output = ( - self.dependency_map._local_output_dependency_map.get(connection, None) + self.dependency_map.local_output_dependency_map.get(connection, None) is not None ) if local_connection.key in self.conns.all and connection.key in self.conns.all: local_conn_is_output = ( - self.dependency_map._local_output_dependency_map.get( + self.dependency_map.local_output_dependency_map.get( local_connection, None ) is not None @@ -308,14 +308,14 @@ def _add_connection( ) -> tuple[ConnectionData, Updates]: updates = Updates() outer_key, con_obj = None, None - is_input = local_key in model._input_keys + is_input = local_key in model.input_keys local_connection = model.conns.get_connection(local_key) assert local_connection is not None, "Connection is not found!" # Flags for use in required operations. create_connection = None set_value: ToBeDetermined | str | MainValueType | NullConnection = NOT_GIVEN match_connection = None - d_map = self.dependency_map._local_output_dependency_map + d_map = self.dependency_map.local_output_dependency_map if isinstance( given_connection, MainValueInstance | NullConnection @@ -368,8 +368,8 @@ def _add_connection( con_obj = self.conns.get_connection(outer_key) # Connection match is required. match_connection = True - if given_connection._value is not TBD: - set_value = given_connection._value + if given_connection.value is not TBD: + set_value = given_connection.value if ( not expose and is_input @@ -495,7 +495,7 @@ def _add_connection( # Set connection as input, output, latent input or # internal based on expose and is_input flag. if is_input: - if outer_key not in self._input_keys: + if outer_key not in self.input_keys: if expose: # if con_obj in self.conns.internal_connections: if con_obj in d_map: @@ -505,7 +505,7 @@ def _add_connection( # TODO: We both set given IOKey in handle_auto_conversion and here. # This causes duality and confusion in self.conns. We need to refactor # extend to disentangle this problem. We should avoid using - # self._input_keys, self._output_keys, self._latent_input_keys or + # self.input_keys, self._output_keys, self._latentinput_keys or # self.conns.all in _add_connection. # elif outer_key not in self.conns.all: @@ -554,7 +554,7 @@ def _unroll_template( connections: list[ConnectionType] = [] for idx, connection in enumerate(template.connections): if isinstance(connection, ExtendTemplate): - conn = model.conns.get_connection(list(model._input_keys)[idx]) + conn = model.conns.get_connection(list(model.input_keys)[idx]) assert conn is not None conn_type = conn.metadata.data.__class__ connections.append( @@ -570,7 +570,7 @@ def _unroll_template( **{ local_key: outer_con for local_key, outer_con in zip( - model._input_keys, connections, strict=False + model.input_keys, connections, strict=False ) }, ) @@ -604,11 +604,11 @@ def merge_connections( if connection2 in self.conns.output_connections: if con1_key not in self.conns.output_keys: self.conns.set_connection_type(connection1, KeyType.OUTPUT) - if con1_key in self._input_keys: + if con1_key in self.input_keys: self.conns.set_connection_type(main_connection1, KeyType.INTERNAL) elif ( main_connection2 in self.conns.internal_connections - and con1_key in self._input_keys + and con1_key in self.input_keys ): self.conns.set_connection_type(main_connection1, KeyType.INTERNAL) @@ -626,35 +626,35 @@ def merge_connections( for ( o_conn, key_info, - ) in self.dependency_map._local_output_dependency_map.items(): + ) in self.dependency_map.local_output_dependency_map.items(): if main_connection2 in key_info[1]: - self.dependency_map._local_output_dependency_map[o_conn][1].remove( + self.dependency_map.local_output_dependency_map[o_conn][1].remove( main_connection2 ) - self.dependency_map._local_output_dependency_map[o_conn][1].add( + self.dependency_map.local_output_dependency_map[o_conn][1].add( main_connection1 ) - if main_connection2 in self.dependency_map._local_output_dependency_map: - self.dependency_map._local_output_dependency_map[main_connection1] = ( - self.dependency_map._local_output_dependency_map.pop(main_connection2) + if main_connection2 in self.dependency_map.local_output_dependency_map: + self.dependency_map.local_output_dependency_map[main_connection1] = ( + self.dependency_map.local_output_dependency_map.pop(main_connection2) ) - if main_connection2 in self.dependency_map._local_input_dependency_map: - old_dependencies = self.dependency_map._local_input_dependency_map.pop( + if main_connection2 in self.dependency_map.local_input_dependency_map: + old_dependencies = self.dependency_map.local_input_dependency_map.pop( main_connection2 ) - self.dependency_map._local_input_dependency_map.setdefault( + self.dependency_map.local_input_dependency_map.setdefault( main_connection1, old_dependencies ) for dependecy in old_dependencies: if ( dependecy - not in self.dependency_map._local_input_dependency_map[ + not in self.dependency_map.local_input_dependency_map[ main_connection1 ] ): - self.dependency_map._local_input_dependency_map[ + self.dependency_map.local_input_dependency_map[ main_connection1 ].append(dependecy) @@ -741,7 +741,7 @@ def extend( for key, value in kwargs.items(): # Check if given keys are among model's keys. - if key not in model._input_keys | model.conns.output_keys: + if key not in model.input_keys | model.conns.output_keys: raise KeyError( f"Given '{key}' key is not an input or output of " "the model '{model}'" @@ -749,11 +749,11 @@ def extend( # Check proper naming of given keys. if isinstance(value, str): - if key in model._input_keys and not value.isidentifier(): + if key in model.input_keys and not value.isidentifier(): raise KeyError( f"Given key name {value} is not a proper identifier string!" ) - elif key in model._input_keys: + elif key in model.input_keys: input_values.add(value) else: output_values.add(value) @@ -776,8 +776,8 @@ def extend( if value._shape is not None: shape_info |= {key: value._shape} - if value._type is not None: - type_info[key] = value._type + if value.type is not None: + type_info[key] = value.type elif isinstance(value, NullConnection): continue @@ -850,7 +850,7 @@ def extend( if isinstance(c_input := model.canonical_input, Connection): c_input_obj = self.conns.get_con_by_metadata(c_input.data.metadata) if c_input_obj is not None and c_input_obj.metadata.data.value is TBD: - if c_input_obj not in self.dependency_map._local_output_dependency_map: + if c_input_obj not in self.dependency_map.local_output_dependency_map: # Update canonical input with model canonical input if c_input_obj not in self.conns.input_connections: self._canonical_input = NOT_AVAILABLE @@ -860,7 +860,7 @@ def extend( elif ( self._canonical_input - in self.dependency_map._local_output_dependency_map + in self.dependency_map.local_output_dependency_map ): # Model canonical output used as input than make it None self._canonical_input = NOT_AVAILABLE @@ -868,7 +868,7 @@ def extend( if isinstance(c_output := model.canonical_output, Connection): c_output_obj = self.conns.get_con_by_metadata(c_output.data.metadata) - if c_output_obj not in self.dependency_map._local_input_dependency_map: + if c_output_obj not in self.dependency_map.local_input_dependency_map: # Update canonical output with model canonical output if c_output_obj is None: self._canonical_output = NOT_AVAILABLE @@ -876,8 +876,7 @@ def extend( self._canonical_output = c_output_obj elif ( - self._canonical_output - in self.dependency_map._local_input_dependency_map + self._canonical_output in self.dependency_map.local_input_dependency_map ): # Model canonical output used as input than make it None self._canonical_output = NOT_AVAILABLE @@ -892,23 +891,23 @@ def _extend(self, info: ExtendInfo | PrimitiveModel | Model) -> Self: # Call model with empty arguments if directly model is given. if isinstance(info, PrimitiveModel | Model): info = info() - model, kwargs = info._model, info._connections + model, kwargs = info.model, info.connections if ( - model._canonical_input is not NOT_AVAILABLE + model.canonical_input is not NOT_AVAILABLE and ( - model._canonical_input.key not in kwargs - or kwargs[model._canonical_input.key] is NOT_GIVEN + model.canonical_input.key not in kwargs + or kwargs[model.canonical_input.key] is NOT_GIVEN ) and len(self.dag) > 0 ): kwargs[model._canonical_input.key] = self.canonical_output for key, value in kwargs.items(): - _value = value._name if isinstance(value, IOKey) else value + _value = value.name if isinstance(value, IOKey) else value if isinstance(_value, str) and _value == "": - if key in model._input_keys: + if key in model.input_keys: _value = NOT_GIVEN else: raise KeyError( @@ -968,7 +967,7 @@ def _update_key_name( key_origin = key_prefix + key_origin return new_key, key_origin - def _generate_keys( + def generate_keys( self, symbolic: bool = True, include_internals: bool = True, @@ -982,8 +981,8 @@ def _generate_keys( input_set = set(self.external_keys) keys = "external_keys" else: - input_set = set(self._input_keys) - keys = "_input_keys" + input_set = set(self.input_keys) + keys = "input_keys" sorted_inputs = [ self.dag[m][key].key @@ -1000,7 +999,7 @@ def _generate_keys( if ( self._canonical_input is not NOT_AVAILABLE and key == self._canonical_input.key - and "input" not in self._input_keys + and "input" not in self.input_keys ): # Handle canonical input new_key = "input" @@ -1009,7 +1008,7 @@ def _generate_keys( assert key_origin is not None # Add prefix until key_origin not in underscored_keys and input_keys. while ( - key_origin in (underscored_keys | self._input_keys) + key_origin in (underscored_keys | self.input_keys) or key_origin == "input" ): key_origin = "_" + key_origin @@ -1026,25 +1025,25 @@ def _generate_keys( # (add index to initial key). raw_key = raw_keys[key_origin][0] key_mappings[raw_key] = key_mappings[raw_key] + "_0" - if key_mappings[raw_key] in self._input_keys: + if key_mappings[raw_key] in self.input_keys: new_key, key_origin = self._update_key_name( new_key, underscored_keys, raw_keys, key_mappings, key_origin, - set(self._input_keys), + set(self.input_keys), ) new_key = key_origin + key_suffix - if new_key in self._input_keys: + if new_key in self.input_keys: new_key, key_origin = self._update_key_name( new_key, underscored_keys, raw_keys, key_mappings, key_origin, - set(self._input_keys), + set(self.input_keys), ) raw_keys[key_origin].append(key) key_mappings[key] = new_key @@ -1137,7 +1136,9 @@ def _freeze(self) -> None: ): # self.output_keys += (self.canonical_output.key,) assert isinstance(self._canonical_output, ConnectionData) - self.conns._set_connection_type(self._canonical_output, KeyType.OUTPUT) + self.conns.set_connection_type( + self._canonical_output, KeyType.OUTPUT, safe=False + ) # setattr(self, self._canonical_output.key, self.canonical_output) self.dependency_map.update_all_keys() @@ -1148,7 +1149,7 @@ def _freeze(self) -> None: if m.name is None: m.name = model_names[m] - if self._formula_key is not None: + if self.formula_key is not None: # Must be convertable to primitive. assert len(self.conns.output_keys) == 1, ( "Logical models have altenative primitive implementation must " @@ -1229,12 +1230,12 @@ def extract_connection_info( data_memo: Mapping[int, Tensor | Scalar] | None = None, ): conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {} - if self._input_keys: + if self.input_keys: if data_to_key_map is None: data_to_key_map = {} if data_memo is None: data_memo = {} - model_key_map = {} + model_key_map: dict[BaseModel, dict[str, str]] = {} # handle the case when model is constructed with += operation. In that case, # directly take canonical output as the output_key. @@ -1248,7 +1249,7 @@ def extract_connection_info( else self.conns.output_keys ) # extract key mappings and data map of outer model - key_mappings = self._generate_keys( + key_mappings = self.generate_keys( include_internals=False, include_outputs=True ) data_map = {key: conn.metadata.data for key, conn in self.conns.all.items()} @@ -1262,9 +1263,9 @@ def extract_connection_info( # set default structure of conn_info and shape_info conns = conn_info.setdefault(model_name, ({}, {})) # include input keys with Tensor value - input_keys = tuple(model._input_keys) + input_keys = tuple(model.input_keys) # Generate sub_model key_map and data map - model_key_map[model] = m_key_mappings = model._generate_keys( + model_key_map[model] = m_key_mappings = model.generate_keys( include_internals=False, include_outputs=True ) m_data_map = { @@ -1296,22 +1297,22 @@ def extract_connection_info( if (val := key_data.value) is not TBD: conn.append(str(val)) - elif outer_key in self._input_keys: + elif outer_key in self.input_keys: # If outer_key in input_keys of overall model, it means # the input key is overall input to the model. Do the # updates accordingly input_name = ["'" + key + "'" for key in updated_outer_key] conn.extend(input_name) else: - # if input_key is not in self._input_keys, that means this + # if input_key is not in self.input_keys, that means this # input key connected to a model and it is an internal # connection. Find the connected model and do the intializations - con_model = self.dependency_map._local_output_dependency_map[ + con_model = self.dependency_map.local_output_dependency_map[ outer_conn ][0] con_generated_keys = model_key_map.setdefault( con_model, - con_model._generate_keys( + con_model.generate_keys( include_internals=False, include_outputs=True ), ) @@ -1337,15 +1338,15 @@ def extract_connection_info( # Lastly, traverse through output keys of the overall model # Find the connected model, and find the inner key by finding # the metadata - metadata = self.conns._get_metadata(outer_key) + metadata = self.conns.get_metadata(outer_key) outer_out_conn = self.conns.get_connection(outer_key) assert metadata is not None, "Metadata is not found!" assert outer_out_conn is not None, "Connection is not found" - model = self.dependency_map._local_output_dependency_map[ - outer_out_conn - ][0] + model = self.dependency_map.local_output_dependency_map[outer_out_conn][ + 0 + ] other_conn = model.conns.get_con_by_metadata(metadata) assert other_conn is not None, "Connection is not found" diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 9d12df5e..7ff2a0ff 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -56,7 +56,7 @@ def __init__( name: str | None = None, **kwargs: IOKey | Tensor | Scalar, ) -> None: - self._formula_key = formula_key + self.formula_key = formula_key self.grad_formula = formula_key + "_grad" super().__init__(name=name) @@ -73,9 +73,9 @@ def __init__( for key, value in kwargs.items(): # TODO: The first if block is temporary. All if else blocks will be # removed after the implementation of the new type system. - if get_origin(value._type) is Union: - args = get_args(value._type) - types = [] + if get_origin(value.type) is Union: + args = get_args(value.type) + types: list[type] = [] for _type in args: # TODO: assertion will be removed, # we should allow Scalar|Tensor type simultaneously. @@ -87,26 +87,26 @@ def __init__( _value: Tensor | Scalar = Tensor( shape=shapes[key].node, possible_types=possible_types, - value=value._value, # type: ignore - interval=value._interval, + value=value.value, # type: ignore + interval=value.interval, ) assert isinstance(_value, Tensor) data_set.add(_value) - elif is_mytensor_type(value._type): + elif is_mytensor_type(value.type): assert isinstance(value, IOKey) _value = Tensor( shape=shapes[key].node, - possible_types=get_mytensor_subtype(value._type), # type: ignore - value=value._value, # type: ignore - interval=value._interval, + possible_types=get_mytensor_subtype(value.type), # type: ignore + value=value.value, # type: ignore + interval=value.interval, ) data_set.add(_value) elif isinstance(value, Tensor | Scalar): _value = value else: _value = Scalar( - possible_types=value._type, # type: ignore - value=value._value, # type: ignore + possible_types=value.type, # type: ignore + value=value.value, # type: ignore ) conn_data = self.create_connection(IOHyperEdge(_value), key) @@ -118,7 +118,7 @@ def __init__( self.conns.set_connection_type(conn_data, KeyType.INPUT) is_diff |= not _value.is_non_diff if isinstance(output_data, Tensor): - output_data._differentiable = is_diff + output_data.differentiable = is_diff # Initially run all given tensors' constraints self.constraint_solver.update_shapes(Updates(data_set)) @@ -129,20 +129,20 @@ def __init__( output_conns = OrderedSet({out_conn}) for conn in self.conns.input_connections: - self.dependency_map._local_input_dependency_map[conn] = [ + self.dependency_map.local_input_dependency_map[conn] = [ (self, output_conns) ] for conn in output_conns: - self.dependency_map._local_output_dependency_map[conn] = (self, input_conns) + self.dependency_map.local_output_dependency_map[conn] = (self, input_conns) - self.dependency_map._cache_internal_references(out_conn, input_conns) + self.dependency_map.cache_internal_references(out_conn, input_conns) self.dependency_map.update_all_keys() # Link canonicals - if isinstance(self.canonical_input, NotAvailable) and len(self._input_keys) > 0: + if isinstance(self.canonical_input, NotAvailable) and len(self.input_keys) > 0: canonical_input_key = ( - "input" if "input" in self._input_keys else next(iter(self._input_keys)) + "input" if "input" in self.input_keys else next(iter(self.input_keys)) ) canonical_input_conn = self.conns.get_connection(canonical_input_key) if canonical_input_conn is None: @@ -189,7 +189,7 @@ def extract_connection_info( conns: tuple[dict[str, list[str]], dict[str, list[str]]] = ({}, {}) # Take the input_keys with tensor values - input_keys = tuple(self._input_keys) + input_keys = tuple(self.input_keys) for key in tuple(input_keys) + tuple(self.conns.output_keys): # find data of the key. diff --git a/mithril/framework/physical/data_store.py b/mithril/framework/physical/data_store.py index 45547385..ee1f9593 100644 --- a/mithril/framework/physical/data_store.py +++ b/mithril/framework/physical/data_store.py @@ -89,9 +89,9 @@ def is_scalar_type(t: Any) -> TypeGuard[MainValueType]: def remove_keys_from_store(self, keys: set[str]): keys -= set(self.graph.output_keys) for key in keys: - self._remove_key_from_store(key, label_as_unused=False, hard_remove=True) + self.remove_key_from_store(key, label_as_unused=False, hard_remove=True) - def _remove_key_from_store( + def remove_key_from_store( self, key: str, label_as_unused: bool = True, hard_remove: bool = False ): if key in self.data_values: @@ -171,7 +171,7 @@ def _infer_unused_keys(self, key: str): if source_key not in output_keys and set( self.graph.get_target_keys(source_key, True) ).issubset(self._unused_keys | self.cached_data.keys()): - self._remove_key_from_store(source_key) + self.remove_key_from_store(source_key) queue |= set( self.graph.get_source_keys(source_key, True) @@ -347,10 +347,10 @@ def infer_static_keys(self) -> Updates: static_value: DataType | MainValueType - fn = fn_dict[model._formula_key] + fn = fn_dict[model.formula_key] # Orginize args and kwargs - local_input_keys = list(model._input_keys) + local_input_keys = list(model.input_keys) if self.backend.is_manualgrad: local_input_keys.append("cache") inputs = { @@ -374,10 +374,10 @@ def infer_static_keys(self) -> Updates: } # If function needs backend specific args - if model._formula_key in self.backend.array_creation_funcs: + if model.formula_key in self.backend.array_creation_funcs: kwargs["precision"] = self.backend.precision if not self.backend.is_manualgrad: - kwargs["device"] = self.backend._device + kwargs["device"] = self.backend.get_device() static_value = fn(*args, **kwargs) diff --git a/mithril/framework/physical/flat_graph.py b/mithril/framework/physical/flat_graph.py index 4df662b0..0a1291c4 100644 --- a/mithril/framework/physical/flat_graph.py +++ b/mithril/framework/physical/flat_graph.py @@ -386,7 +386,7 @@ def _is_duplicate( else: model_id.append(conn.key) - final_model_id = "-".join(model_id) + f"-{node.model._formula_key}" + final_model_id = "-".join(model_id) + f"-{node.model.formula_key}" if final_model_id in self.unique_model_table: return self.unique_model_table[final_model_id] diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index d7e8340e..3fc09918 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -42,9 +42,9 @@ UniadicRecord, Updates, Variadic, - _get_shapes, _ShapesType, create_shape_map, + get_shapes, get_summary, get_summary_shapes, get_summary_types, @@ -99,7 +99,7 @@ def __init__( extend_info = model() model_keys = {} for key in model.external_keys: - value = extend_info._connections.get(key, NOT_GIVEN) + value = extend_info.connections.get(key, NOT_GIVEN) # NOTE: Do not set default value if it is given in constant_keys. value = (value, NOT_GIVEN)[key in constant_keys] default_val = model.conns.get_data(key).value @@ -125,7 +125,7 @@ def __init__( # NOTE: Reconsider updating logical dag in order. self._input_keys: set[str] = { - flat_model.external_mapping[key] for key in model._input_keys + flat_model.external_mapping[key] for key in model.input_keys } # Add canonical output mapping to key_mappings if necessary @@ -174,12 +174,12 @@ def __init__( self.inference = inference # Initialize flat graph and data store. - self._flat_graph: FlatGraph[DataType] = FlatGraph( + self.flat_graph: FlatGraph[DataType] = FlatGraph( self._input_keys, self._output_keys ) memo: dict[int, Tensor | Scalar] = {} self.data_store: StaticDataStore[DataType] = StaticDataStore( - self._flat_graph, backend, inference, model.constraint_solver, memo + self.flat_graph, backend, inference, model.constraint_solver, memo ) for p_model, mappings in flat_model: @@ -205,9 +205,9 @@ def __init__( or physical_data.value is not TBD ): # TODO: Create an API for setting differentiability of a tensor. - physical_data._differentiable = False + physical_data.differentiable = False elif global_key in self._trainable_tensor_inputs: - physical_data._differentiable = True + physical_data.differentiable = True model_data[key] = physical_data self.data_store.data_memo[id(logical_data)] = physical_data @@ -237,7 +237,7 @@ def __init__( cache_scalar = Scalar(dict | None, cache_value) self.data_store.update_data({cache_name: cache_scalar}) - self._flat_graph.add_value(p_model, mappings) + self.flat_graph.add_value(p_model, mappings) for cached_key in list(self.data_store.cached_data.keys()): self.data_store._infer_unused_keys(cached_key) @@ -254,7 +254,7 @@ def __init__( # runtime must be manually named in logical model. if safe_names: runtime_data_keys = self.data_store.runtime_static_keys - unnamed_inputs = model._input_keys - self._input_keys - self.discarded_keys + unnamed_inputs = model.input_keys - self._input_keys - self.discarded_keys unnamed_data_keys = sorted( [ key @@ -380,12 +380,12 @@ def get_shapes( key: self.data_store.data_memo[id(value.metadata.data)] for key, value in model.conns.all.items() } - key_mappings = model._generate_keys(include_outputs=True) + key_mappings = model.generate_keys(include_outputs=True) else: data_dict = self.data key_mappings = None - return _get_shapes( + return get_shapes( data_dict=data_dict, uniadic_keys=uni_keys, varadic_keys=var_keys, @@ -406,6 +406,10 @@ def shapes(self) -> _ShapesType: def output_keys(self): return sorted(self._output_keys) + @property + def input_keys(self): + return self._input_keys + def _infer_differentiability(self, model: PrimitiveModel, dag: dict[str, str]): # Infer output differentiability only for the models # that have a Tensor type output. @@ -418,11 +422,11 @@ def _infer_differentiability(self, model: PrimitiveModel, dag: dict[str, str]): key != PrimitiveModel.output_key and not self.data[value].is_non_diff ): - self.data[output_key]._differentiable = True + self.data[output_key].differentiable = True return # If all inputs are non-differentiable, then the output is also # non-differentiable. - self.data[output_key]._differentiable = False + self.data[output_key].differentiable = False def randomize_params( self, @@ -512,17 +516,17 @@ def _pre_compile( self.jacobian_keys = jacobian_keys self.ignore_grad_keys: set[str] = set() - for node in self._flat_graph.nodes.values(): + for node in self.flat_graph.nodes.values(): conn_data = node.model.conns.get_connection("output") assert conn_data is not None if isinstance(conn_data.metadata.data, Scalar) or ( - not find_intersection_type(float, conn_data.metadata.data._type) + not find_intersection_type(float, conn_data.metadata.data.type) ): self.ignore_grad_keys.add( node.connections[PrimitiveModel.output_key].key ) - pruned_keys = self._flat_graph.prune_duplicate_nodes(self.data, constant_keys) + pruned_keys = self.flat_graph.prune_duplicate_nodes(self.data, constant_keys) updates = Updates() @@ -562,7 +566,7 @@ def _pre_compile( # of the model nor an input to a PrimitiveModel. self.discarded_keys |= { - key for key in self._flat_graph.hanging_keys if key not in self.output_keys + key for key in self.flat_graph.hanging_keys if key not in self.output_keys } self.discarded_keys, self._output_keys = self.infer_ignore( @@ -717,7 +721,7 @@ def infer_ignore( non_leaf_keys = { key for key in weak_keys - if key in self._flat_graph.all_source_keys and key in output_keys + if key in self.flat_graph.all_source_keys and key in output_keys } # Internal keys will be removed from output_keys but also they will # be removed from current ignored keys. @@ -729,13 +733,13 @@ def infer_ignore( key = queue.pop() # try forward inference (check if any inference is possible # from inputs to outputs) - self._flat_graph.infer_ignore_step(key, keys, queue, from_source=True) + self.flat_graph.infer_ignore_step(key, keys, queue, from_source=True) # try bacward inference (check if any inference possible # from outputs to inputs) - self._flat_graph.infer_ignore_step(key, keys, queue, from_source=False) + self.flat_graph.infer_ignore_step(key, keys, queue, from_source=False) if update_graph: - self._flat_graph.remove_key(key) + self.flat_graph.remove_key(key) output_keys.discard(key) self._input_keys.discard(key) @@ -760,11 +764,11 @@ def _calculate_parameters( - self.data_store.runtime_static_keys ) for model, model_name in name_mappings.items(): - key_mappings = model._generate_keys(include_outputs=True) + key_mappings = model.generate_keys(include_outputs=True) for key in model.external_keys: in_dict, out_dict = param_info.setdefault(model_name, ({}, {})) inner_key = key_mappings.get(key, key) - if key not in model._input_keys: + if key not in model.input_keys: # case where the key is not an input key (hence not a trainable) out_dict[inner_key] = "0" continue @@ -918,10 +922,10 @@ def summary( ) else: # Remove unused models and cached models - all_models = list(self._flat_graph.get_models()) + all_models = list(self.flat_graph.get_models()) for key in self.data_store.unused_keys | self.data_store.cached_data.keys(): if ( - unused_model := self._flat_graph.connections.get(key) + unused_model := self.flat_graph.connections.get(key) ) is not None and unused_model.node is not None: all_models.remove(unused_model.node.model) @@ -985,13 +989,13 @@ def extract_connection_info( self, name_mappings: dict[PrimitiveModel, str] | None = None ): if name_mappings is None: - name_mappings = define_unique_names(self._flat_graph.get_models()) + name_mappings = define_unique_names(self.flat_graph.get_models()) conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {} for model, model_name in name_mappings.items(): conn_info.setdefault(model_name, ({}, {})) - model_node = self._flat_graph.nodes[model] - input_keys = tuple(model._input_keys) + model_node = self.flat_graph.nodes[model] + input_keys = tuple(model.input_keys) for input_key in input_keys: connection = model_node.connections[input_key] @@ -1032,8 +1036,8 @@ def extract_connection_info( for output_key in self.output_keys: # Traverse output_keys of overall model and make indications accordingly - outer_key = self._flat_graph.output_dict.get(output_key, output_key) - output_connection = self._flat_graph.connections[outer_key] + outer_key = self.flat_graph.output_dict.get(output_key, output_key) + output_connection = self.flat_graph.connections[outer_key] assert output_connection.node is not None model = output_connection.node.model model_name = name_mappings[model] @@ -1046,8 +1050,8 @@ def extract_connection_info( def _replace_with_primitive( self, model: Model, key_mappings: dict[str, str] ) -> tuple[PrimitiveModel, dict[str, str]]: - assert model._formula_key is not None - formula = self.backend.primitive_function_dict[model._formula_key] + assert model.formula_key is not None + formula = self.backend.primitive_function_dict[model.formula_key] primitive_input_keys = formula.__code__.co_varnames[ : formula.__code__.co_argcount ] # make function? @@ -1055,15 +1059,15 @@ def _replace_with_primitive( # Remove unnecessary keys unnecessary_keys = { key: key_mappings.get(key, key) - for key in (set(model._input_keys) - set(primitive_input_keys)) + for key in (set(model.input_keys) - set(primitive_input_keys)) } - input_keys = list(model._input_keys) + input_keys = list(model.input_keys) external_keys = list(model.external_keys) for key, val in unnecessary_keys.items(): # self.static_keys.pop(val) # self.non_differentiables.pop(val) - self.data_store._remove_key_from_store(val, label_as_unused=False) + self.data_store.remove_key_from_store(val, label_as_unused=False) self.data.pop(val) self._input_keys.discard(val) input_keys.remove(key) @@ -1076,7 +1080,7 @@ def _replace_with_primitive( kwargs = {key: model.conns.all[key].metadata.data for key in external_keys} primitive = PrimitiveModel( - formula_key=model._formula_key, name=model.name, **kwargs + formula_key=model.formula_key, name=model.name, **kwargs ) primitive.parent = model.parent @@ -1095,7 +1099,7 @@ def evaluate( ) -> DataEvalType[DataType]: if ( isinstance(self.backend, ParallelBackend) - and self.backend._parallel_manager is not None + and self.backend.get_parallel_manager() is not None ): return self.backend._run_callable(params, data, fn_name="eval_fn") else: @@ -1113,7 +1117,7 @@ def evaluate_gradients( ) if ( isinstance(self.backend, ParallelBackend) - and self.backend._parallel_manager is not None + and self.backend.get_parallel_manager() is not None ): return self.backend._run_callable( params, data, output_gradients, fn_name="eval_grad_fn" @@ -1133,7 +1137,7 @@ def evaluate_all( ) if ( isinstance(self.backend, ParallelBackend) - and self.backend._parallel_manager is not None + and self.backend.get_parallel_manager() is not None ): return self.backend._run_callable( params, data, output_gradients, fn_name="eval_all_fn" @@ -1190,7 +1194,7 @@ def __init__( self.short_namings = short_namings self._name_externals() - self._generate_keys(model) + self.generate_keys(model) self._rebase_names() @property @@ -1277,7 +1281,7 @@ def _name_externals(self): self._external_mapping[base_name_str] = name self.assigned_edges[conn.metadata] = name - if key in self.model._input_keys: + if key in self.model.input_keys: self.used_edges.add(conn.metadata) self.external_edges[conn.metadata] = base_name_str @@ -1320,7 +1324,7 @@ def _get_unique_name_str( return self._get_next_unique_name(base_name) return base_name - def _generate_keys( + def generate_keys( self, model: BaseModel, mappings: dict[str, str] | None = None, @@ -1415,7 +1419,7 @@ def _process_model(self, model: Model, mappings: dict[str, str], parent_name: st else: name_mapping[key] = mappings[conn.key] - self._generate_keys(m, name_mapping, parent_name=name) + self.generate_keys(m, name_mapping, parent_name=name) def _check_for_queue(self, hyperedge: IOHyperEdge): if hyperedge in self.queued_models: diff --git a/mithril/models/models.py b/mithril/models/models.py index 1a2e638b..4733354f 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -26,6 +26,7 @@ Connection, ConnectionType, IOKey, + MainValueType, ShapeTemplateType, TensorValueType, ToBeDetermined, @@ -587,9 +588,9 @@ def __call__( # type: ignore[override] ) -> ExtendInfo: kwargs = {"input": input, "weight": weight, "output": output} - if "bias" not in self._input_keys and bias != NOT_GIVEN: + if "bias" not in self.input_keys and bias != NOT_GIVEN: raise KeyError("bias is not a valid input when 'use_bias' is False!") - elif "bias" in self._input_keys: + elif "bias" in self.input_keys: kwargs["bias"] = bias return super().__call__(**kwargs) @@ -740,14 +741,14 @@ def __call__( # type: ignore[override] ) -> ExtendInfo: kwargs = {"input": input, "output": output} - if "weight" not in self._input_keys and weight != NOT_GIVEN: + if "weight" not in self.input_keys and weight != NOT_GIVEN: raise KeyError("weight is not a valid input when 'use_scale' is False!") - elif "weight" in self._input_keys: + elif "weight" in self.input_keys: kwargs["weight"] = weight - if "bias" not in self._input_keys and bias != NOT_GIVEN: + if "bias" not in self.input_keys and bias != NOT_GIVEN: raise KeyError("bias is not a valid input when 'use_bias' is False!") - elif "bias" in self._input_keys: + elif "bias" in self.input_keys: kwargs["bias"] = bias return super().__call__(**kwargs) @@ -815,14 +816,14 @@ def __call__( # type: ignore[override] ) -> ExtendInfo: kwargs = {"input": input, "output": output} - if "weight" not in self._input_keys and weight != NOT_GIVEN: + if "weight" not in self.input_keys and weight != NOT_GIVEN: raise KeyError("weight is not a valid input when 'use_scale' is False!") - elif "weight" in self._input_keys: + elif "weight" in self.input_keys: kwargs["weight"] = weight - if "bias" not in self._input_keys and bias != NOT_GIVEN: + if "bias" not in self.input_keys and bias != NOT_GIVEN: raise KeyError("bias is not a valid input when 'use_bias' is False!") - elif "bias" in self._input_keys: + elif "bias" in self.input_keys: kwargs["bias"] = bias return super().__call__(**kwargs) @@ -1079,7 +1080,7 @@ def __init__( weight: TensorValueType | ToBeDetermined = TBD, bias: TensorValueType | ToBeDetermined = TBD, ) -> None: - if len(kernel._input_keys) < 2: + if len(kernel.input_keys) < 2: raise KeyError("Kernel requires at least two inputs!") if len(kernel.conns.output_keys) != 1: raise KeyError("Kernel requires single output!") @@ -1096,7 +1097,7 @@ def __init__( # Get kernel inputs from given model. kernel_input_args = { key: key - for key in kernel._input_keys + for key in kernel.input_keys if not kernel.conns.is_key_non_diff(key) } (kernel_output_name,) = kernel.conns.output_keys # NOTE: Assumes single output! @@ -1514,7 +1515,7 @@ def __init__( self += slice_model_2(input="prev_hidden", stop=scalar_item.output) body_kwargs: dict[str, ConnectionType] = { - key: key for key in cell_body._input_keys if key[0] != "$" + key: key for key in cell_body.input_keys if key[0] != "$" } body_kwargs["prev_cell"] = slice_model_1.output body_kwargs["prev_hidden"] = slice_model_2.output @@ -1748,7 +1749,12 @@ def __call__( # type: ignore[override] class RNN(Model): - def __init__(self, cell_type: Cell, name: str | None = None, **kwargs) -> None: + def __init__( + self, + cell_type: Cell, + name: str | None = None, + **kwargs: TensorValueType | MainValueType, + ) -> None: self.cell_type = cell_type super().__init__(name=name) self.factory_inputs = kwargs @@ -1767,7 +1773,7 @@ def __init__( teacher_forcing: bool = False, name: str | None = None, input: TensorValueType | ToBeDetermined = TBD, - **kwargs, + **kwargs: TensorValueType | MainValueType, ) -> None: super().__init__(cell_type=cell_type, name=name) self.factory_inputs = {"input": input, **kwargs} @@ -1843,7 +1849,7 @@ def __init__( max_sequence_length: int, name: str | None = None, input: TensorValueType | ToBeDetermined = TBD, - **kwargs, + **kwargs: TensorValueType | ToBeDetermined, ) -> None: super().__init__(cell_type=cell_type, name=name) self.factory_inputs = {"input": input, **kwargs} @@ -1883,7 +1889,7 @@ def __init__( prev_cell = current_cell self._freeze() - def __call__( # type: ignore[override] + def __call__( self, input: ConnectionType = NOT_GIVEN, **model_keys: ConnectionType ) -> ExtendInfo: return super(RNN, self).__call__(input=input, **model_keys) @@ -1898,7 +1904,7 @@ def __init__( max_sequence_length: int, name: str | None = None, hidden_concat: TensorValueType | ToBeDetermined = TBD, - **kwargs, + **kwargs: TensorValueType | ToBeDetermined, ) -> None: super().__init__(cell_type, name=name) self.factory_inputs = {"hidden_concat": hidden_concat, **kwargs} @@ -1972,7 +1978,7 @@ def __init__( teacher_forcing: bool = False, name: str | None = None, indices: TensorValueType | ToBeDetermined = TBD, - **kwargs, + **kwargs: TensorValueType | ToBeDetermined, ) -> None: super().__init__(name=name) self.factory_inputs = {"indices": indices, **kwargs} @@ -1991,11 +1997,11 @@ def __init__( permutation_model = PermuteTensor() - enc_input_mapping = {key: key for key in encoder._input_keys if "$" not in key} + enc_input_mapping = {key: key for key in encoder.input_keys if "$" not in key} dec_input_mapping = { key: "decoder_" + key if "target" not in key else key - for key in decoder._input_keys + for key in decoder.input_keys if "$" not in key and key != "initial_hidden" } @@ -2010,7 +2016,7 @@ def __init__( self._freeze() - def __call__( # type: ignore[override] + def __call__( self, indices: ConnectionType = NOT_GIVEN, **model_keys: ConnectionType ) -> ExtendInfo: return super().__call__(indices=indices, **model_keys) @@ -2026,7 +2032,7 @@ def __init__( max_target_sequence_length: int, name: str | None = None, indices: TensorValueType | ToBeDetermined = TBD, - **kwargs, + **kwargs: TensorValueType | ToBeDetermined, ) -> None: super().__init__(name=name) self.factory_inputs = {"indices": indices, **kwargs} @@ -2041,11 +2047,11 @@ def __init__( cell_type=cell_type, max_sequence_length=max_target_sequence_length ) - enc_input_mapping = {key: key for key in encoder._input_keys if "$" not in key} + enc_input_mapping = {key: key for key in encoder.input_keys if "$" not in key} dec_input_mapping = { key: "decoder_" + key if "target" not in key else key - for key in decoder._input_keys + for key in decoder.input_keys if "$" not in key and key != "initial_hidden" } @@ -2078,7 +2084,7 @@ def __init__( name: str | None = None, input1: TensorValueType | ToBeDetermined = TBD, input2: TensorValueType | ToBeDetermined = TBD, - **kwargs, + **kwargs: TensorValueType | ToBeDetermined, ) -> None: super().__init__(name=name) self.factory_inputs = {"input1": input1, "input2": input2, **kwargs} @@ -2373,7 +2379,7 @@ def __call__( # type: ignore[override] "output": output, } - if "p_joint" in self._input_keys: + if "p_joint" in self.input_keys: kwargs["p_joint"] = p_joint elif p_joint != NOT_GIVEN: raise ValueError("p_joint is only required when calculate_p_joint is True!") @@ -2432,7 +2438,7 @@ def __init__( if base_model.requires_norm: base_kwargs["norm"] = "norm" - for key in base_model._input_keys: + for key in base_model.input_keys: con = base_model.conns.get_connection(key) assert con is not None if key not in base_kwargs and not con.is_key_autogenerated: @@ -2484,7 +2490,7 @@ def __call__( # type: ignore[override] "predicted_coords": predicted_coords, "output": output, } - if "input" in self._input_keys: + if "input" in self.input_keys: kwargs["input"] = input elif coords != NOT_GIVEN: raise ValueError("coords is only required when input_type is 'points'!") @@ -2537,7 +2543,7 @@ def __call__( # type: ignore[override] "output": output, } - if "coords" in self._input_keys: + if "coords" in self.input_keys: kwargs["coords"] = coords elif coords != NOT_GIVEN: raise ValueError("coords is only required when input_type is 'points'!") diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index 9aaf7277..fbc353c6 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -391,7 +391,7 @@ def __call__( # type: ignore[override] "output": output, } # Check if the given argument set is valid. - if self._formula_key == "cross_entropy_with_log_probs": + if self.formula_key == "cross_entropy_with_log_probs": args: list[str] = [] if robust is not False: args.append("robust") @@ -487,9 +487,7 @@ def __init__( ) pos_weight = True - pos_weight_type = ( - float | bool if pos_weight in (..., None) else type(pos_weight) - ) + pos_weight_type = type(pos_weight) kwargs: dict[str, IOKey] = { "output": IOKey(shape=[("Var_out", ...)], type=MyTensor[float]), "input": IOKey(shape=[("Var_out", ...)], type=GenericTensorType), @@ -896,7 +894,7 @@ def __init__( # self.factory_inputs = {key: value for key, value in kwargs.items()} self.factory_inputs = kwargs # type: ignore - input_keys = [key for key in self._input_keys if key != "axis"] + input_keys = [key for key in self.input_keys if key != "axis"] self._set_constraint( fn=concat_constraints, keys=["output"] + ["axis"] + input_keys ) @@ -1049,10 +1047,10 @@ def __call__( # type: ignore[override] "output": output, } - if "bias" not in self._input_keys and bias != NOT_GIVEN: + if "bias" not in self.input_keys and bias != NOT_GIVEN: raise ValueError(f"Model does not have 'bias' input. \ Got {bias} as bias argument!") - elif "bias" in self._input_keys: + elif "bias" in self.input_keys: kwargs |= {"bias": bias} return super().__call__(**kwargs) @@ -1148,11 +1146,11 @@ def __call__( # type: ignore[override] "output": output, } - if "bias" not in self._input_keys and bias != NOT_GIVEN: + if "bias" not in self.input_keys and bias != NOT_GIVEN: raise ValueError( f"Model does not have 'bias' input. Got {bias} as bias argument!" ) - elif "bias" in self._input_keys: + elif "bias" in self.input_keys: kwargs |= {"bias": bias} return super().__call__(**kwargs) @@ -2144,7 +2142,7 @@ def __call__( # type: ignore[override] and attn_mask is not NOT_GIVEN and not isinstance(attn_mask, str) and isinstance(attn_mask, IOKey) - and attn_mask._value is not None # TODO: Here will be updated! + and attn_mask.value is not None # TODO: Here will be updated! ): raise KeyError( "Model does not have 'attn_mask' input. Got attn_mask argument!" diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index 27e630e2..deab9392 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -15,7 +15,7 @@ import re from collections.abc import Callable from copy import deepcopy -from typing import Any +from typing import Any, Self from ..framework import ( NOT_GIVEN, @@ -30,7 +30,7 @@ Model, UniadicRecord, Variadic, - _get_shapes, + get_shapes, get_summary_shapes, ) from ..framework.common import TBD, NotAvailable, Table @@ -71,12 +71,12 @@ def __init__(self, model: BaseModel) -> None: self._is_finalized = False self.factory_args = {"model": model} # TODO: If we add inputs as IOKey, we get multi-write error. Fix this. - key_mappings = model._generate_keys(symbolic=False, include_internals=True) + key_mappings = model.generate_keys(symbolic=False, include_internals=True) extend_kwargs = { key: key_mappings.get( key, IOKey(name=key) if key in model.conns.output_keys else key ) - for key in model._input_keys | model.conns.output_keys + for key in model.input_keys | model.conns.output_keys } if LossKey in extend_kwargs: @@ -98,7 +98,7 @@ def __init__(self, model: BaseModel) -> None: self.geomean_map: dict[str, list[tuple[Connection, float]]] = {} self.reduce_inputs: dict[str, list[tuple[Connection, Connection]]] = {} - def __add__(self, model: ExtendInfo | PrimitiveModel | Model): + def __add__(self, model: ExtendInfo | PrimitiveModel | Model) -> Self: """This function allows models to be added sequentially via "+=" operator. There are several conditions for a model to be sequentially added: if added model has single input, connect that input directly. @@ -153,8 +153,8 @@ def add_loss( # If provided key namings does not match with Loss model if { key - for key, value in loss_model(**kwargs)._connections.items() - if value is NOT_GIVEN and key in loss_model._input_keys + for key, value in loss_model(**kwargs).connections.items() + if value is NOT_GIVEN and key in loss_model.input_keys } - loss_model.conns.get_non_diff_keys(): # if set(kwargs.keys()) != keys: raise KeyError("The provided keys do not match the model's loss.") @@ -227,7 +227,7 @@ def add_loss( raise KeyError("Given key does not belong to the Model!") loss_key = prev_con.key for i, m in enumerate(reduce_steps): - in_key = m._canonical_input.key + in_key = m.canonical_input.key if i == len(reduce_steps) - 1 and key_name is not None and coef is None: out_key = self.get_single_output(m).key # self.extend(m, **{in_key: prev_out_key.conn, out_key: key_name}) @@ -274,7 +274,7 @@ def add_regularization( key_name: str | None = None, **kwargs: Any, ): - keys = set(model._input_keys) - model.conns.get_non_diff_keys() + keys = set(model.input_keys) - model.conns.get_non_diff_keys() if set(kwargs.keys()) != keys: raise KeyError( "The provided keys do not match the regularization model keys!" @@ -319,7 +319,7 @@ def _add_regularization( case Connection(): reg_str = reg_key.data.key case None: - reg_str = model._canonical_input.key + reg_str = model.canonical_input.key if any([isinstance(value, re.Pattern) for value in kwargs.values()]): if len(kwargs) > 1: raise Exception( @@ -328,7 +328,7 @@ def _add_regularization( ) else: (regex,) = kwargs.values() - for key in tuple(self._input_keys): + for key in tuple(self.input_keys): if re.search(regex, key): self._add_regularization( deepcopy(model), @@ -337,11 +337,11 @@ def _add_regularization( **{reg_str: key}, ) else: - generated_keys = self._generate_keys(symbolic=False) + generated_keys = self.generate_keys(symbolic=False) non_diff_keys = { generated_keys.get(key, key) for key in self.conns.get_non_diff_keys() } - input_keys = {key for key in self._input_keys if "$" not in key} + input_keys = {key for key in self.input_keys if "$" not in key} trainable_keys = (input_keys | set(generated_keys.values())) - non_diff_keys trainables: set[IOHyperEdge] = set() for key in trainable_keys: @@ -404,7 +404,7 @@ def add_metric( prev_out_key = self.get_single_output(model) for i, m in enumerate(reduce_steps): - in_key = m._canonical_input.key + in_key = m.canonical_input.key if i == len(reduce_steps) - 1 and key_name is not None: out = self.get_single_output(m).data # self.extend(m, **{in_key: prev_out_key, out.key: key_name}) @@ -436,7 +436,7 @@ def _add_loss_combiner(self): concat_model = Concat(n=num_of_loss_keys, axis=None) concat_kwargs: dict[Any, Any] = {} idx = 0 - for key in concat_model._input_keys: + for key in concat_model.input_keys: # if not concat_model.connections[key].metadata.value.is_non_diff: if not concat_model.conns.is_key_non_diff(key): concat_kwargs[key] = self.conns.all[ @@ -464,7 +464,7 @@ def _add_loss_combiner(self): output=IOKey(name=loss_output_key), ) - def _finalize(self): + def finalize(self): # Apply finalization steps if and only if not finalized before. if not self._is_finalized: self._add_geo_mean() @@ -520,7 +520,7 @@ def summary( # TODO: Check the way we provide "depth" argument # to the model.summary() method. - summary_kwargs = { + summary_kwargs: dict[str, Any] = { "shapes": shapes, "types": types, "symbolic": symbolic, @@ -538,7 +538,7 @@ def summary( model_shapes = {} for sub_model, sub_model_name in name_mappings.items(): - model_shapes[sub_model_name] = _get_shapes( + model_shapes[sub_model_name] = get_shapes( data_dict={ key: value.metadata.data for key, value in sub_model.conns.all.items() @@ -548,7 +548,7 @@ def summary( varadic_keys=var_cache, symbolic=symbolic, verbose=False, - key_mappings=sub_model._generate_keys( + key_mappings=sub_model.generate_keys( include_internals=False, include_outputs=True ), ) @@ -566,10 +566,10 @@ def summary( ) self._model.get_shapes(uni_cache, var_cache, symbolic, verbose=False) for loss_key, loss_dict in zip(self.loss_keys, self._losses, strict=False): - t_list = [] + t_list: list[list[str]] = [] loss_conn = self.conns.get_connection(loss_key) assert loss_conn is not None - model = self.dependency_map._local_output_dependency_map[loss_conn][0] + model = self.dependency_map.local_output_dependency_map[loss_conn][0] t_list.append([model.__class__.__name__]) m_name = name_mappings[model] conns = conn_info[m_name][0] @@ -611,15 +611,15 @@ def summary( assert conn.metadata is not None conn_data = self.conns.get_con_by_metadata(conn.metadata) assert conn_data is not None - model = self.dependency_map._local_output_dependency_map[conn_data][ + model = self.dependency_map.local_output_dependency_map[conn_data][ 0 ] r_list.append([model.__class__.__name__]) m_name = name_mappings[model] conns = conn_info[m_name][0] shape = shape_info[m_name][0] - reg_key = model._canonical_input.key - updated_reg_key = model._generate_keys(include_outputs=True).get( + reg_key = model.canonical_input.key + updated_reg_key = model.generate_keys(include_outputs=True).get( reg_key, reg_key ) r_list.append(conns[updated_reg_key]) @@ -643,7 +643,7 @@ def summary( m_list: list[list[str]] = [] m_conn = self.conns.get_connection(m_key) assert m_conn is not None - model = self.dependency_map._local_output_dependency_map[m_conn][0] + model = self.dependency_map.local_output_dependency_map[m_conn][0] m_list.append([model.__class__.__name__]) m_name = name_mappings[model] conns = conn_info[m_name][0] @@ -686,8 +686,9 @@ def _add_geo_mean(self): final_output = final_outputs[0] if (n_final_outputs := len(final_outputs)) > 0: concat_model = Concat(n=n_final_outputs, axis=None) - concat_kwargs, idx = {}, 0 - for key in concat_model._input_keys: + concat_kwargs: dict[str, int | Connection] = {} + idx = 0 + for key in concat_model.input_keys: if not concat_model.conns.is_key_non_diff(key): concat_kwargs[key] = final_outputs[idx] idx += 1 @@ -727,8 +728,9 @@ def _add_reduce_sizes(self, reduce_list: list[tuple[Connection, Connection]]): if (num_of_sizes := len(sizes)) > 0: concat_model = Concat(n=num_of_sizes, axis=None) - concat_kwargs, idx = {}, 0 - for key in concat_model._input_keys: + concat_kwargs: dict[str, int | Connection] = {} + idx = 0 + for key in concat_model.input_keys: if not concat_model.conns.is_key_non_diff(key): concat_kwargs[key] = sizes[idx] idx += 1 diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 8cb17f93..8dac5a40 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -139,7 +139,7 @@ def dict_to_model(modelparams: dict[str, Any]) -> BaseModel: submodels_dict[m_key] = m mappings: dict[str, IOKey | float | int | list | tuple | str] = {} for k, conn in connections[m_key].items(): - if conn in unnamed_keys and k in m._input_keys: + if conn in unnamed_keys and k in m.input_keys: continue if isinstance(conn, str | float | int | tuple | list): @@ -214,7 +214,7 @@ def model_to_dict(model: BaseModel) -> dict: for key, con in model.conns.all.items(): data = con.metadata.data if isinstance(data, Tensor) and not con.is_key_autogenerated: - model_dict["differentiability_info"][key] = data._differentiable + model_dict["differentiability_info"][key] = data.differentiable for shape in model.assigned_shapes: model_dict["assigned_shapes"] |= shape_to_dict(shape) @@ -244,17 +244,17 @@ def model_to_dict(model: BaseModel) -> dict: # Store submodel connections for key in submodel._all_keys: submodel_connections.setdefault( - submodel.conns._get_metadata(key), [model_id, key] + submodel.conns.get_metadata(key), [model_id, key] ) assert isinstance(model, Model) connection_dict[model_id] = connection_to_dict( model, submodel, submodel_connections, model_id ) canonical_keys[model_id] = ( - submodel._canonical_input.key, - submodel._canonical_output.key, + submodel.canonical_input.key, + submodel.canonical_output.key, ) - canonical_keys["model"] = (model._canonical_input.key, model._canonical_output.key) + canonical_keys["model"] = (model.canonical_input.key, model._canonical_output.key) model_dict["submodels"] = submodels model_dict["connections"] = connection_dict @@ -538,8 +538,8 @@ def item_to_json(item: IOKey): # TODO: Currently type is not supported for Tensors. # Handle This whit conversion test updates. result: dict[str, Any] = {} - if not isinstance(item._value, ToBeDetermined): - result["value"] = item._value + if not isinstance(item.value, ToBeDetermined): + result["value"] = item.value if item._shape is not None: shape_template = [] for symbol in item._shape: @@ -549,10 +549,10 @@ def item_to_json(item: IOKey): shape_template.append(str(symbol)) result["shape_template"] = shape_template - elif isinstance(item._type, UnionType): - result["type"] = [type_to_str(item) for item in item._type.__args__] + elif isinstance(item.type, UnionType): + result["type"] = [type_to_str(item) for item in item.type.__args__] else: result["type"] = [ - type_to_str(item._type), + type_to_str(item.type), ] return result diff --git a/mithril/utils/utils.py b/mithril/utils/utils.py index f4cc7cb9..d74cb1d0 100755 --- a/mithril/utils/utils.py +++ b/mithril/utils/utils.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Iterator, KeysView, MutableMapping +from collections.abc import Callable, Iterable, Iterator, MutableMapping from enum import Enum, IntEnum from itertools import compress from typing import Any, Generic, TypeVar @@ -40,39 +40,6 @@ class PaddingType(IntEnum): SAME = 1 -def topological_sort_dfs(graph: dict[Any, set | KeysView]) -> list: - """Finds topological sort using Depth-first search. - - Parameters - ---------- - graph : dict[Any, set | KeysView] - Dictionary which contains graph of the node relations. - - Returns - ------- - list - List of topologically sorted nodes. - """ - # Insert "m_start" node as the start node for dfs. - graph["m_start"] = graph.keys() - # graph["m_start"] = set(graph.keys()) - seen = set() - stack: list[str] = [] # path variable is gone, stack and order are new - order = [] # order will be in reverse order at first - q = ["m_start"] - while q: - v = q.pop() - if v not in seen: - seen.add(v) # no need to append to path any more - q.extend(graph[v]) - while stack and v not in graph[stack[-1]]: # new stuff here! - order.append(stack.pop()) - stack.append(v) - # Remove "m_start" node before returning the topology. - raise Exception("This function will be replaced by 'topological_order_sorting'.") - return (stack + order[::-1])[1:] - - def convert_specs_to_dict(specs): """This function converts given specs to dict which could be saved to JSON. @@ -148,21 +115,21 @@ def symmetric_difference_update(self, iterable: Iterable[T]) -> None: else: self.add(item) - def union(self, *iterables: Iterable) -> OrderedSet[T]: + def union(self, *iterables: Iterable[T]) -> OrderedSet[T]: new_set = OrderedSet(self) new_set.update(*iterables) return new_set - def difference(self, *iterables: Iterable) -> OrderedSet[T]: + def difference(self, *iterables: Iterable[T]) -> OrderedSet[T]: new_set = OrderedSet(self) new_set.difference_update(*iterables) return new_set - def intersection(self, *iterables: Iterable) -> OrderedSet[T]: + def intersection(self, *iterables: Iterable[T]) -> OrderedSet[T]: common_items = set(self._data).intersection(*iterables) return OrderedSet(common_items) - def symmetric_difference(self, iterable: Iterable) -> OrderedSet[T]: + def symmetric_difference(self, iterable: Iterable[T]) -> OrderedSet[T]: new_set = OrderedSet(self) new_set.symmetric_difference_update(iterable) return new_set @@ -205,24 +172,20 @@ def __ior__(self, other: OrderedSet[T]) -> OrderedSet[T]: self._data |= other._data return self - def __ror__(self, other): + def __ror__(self, other: OrderedSet[T]) -> OrderedSet[T]: return self.union(other) - def __and__(self, other): + def __and__(self, other: OrderedSet[T]) -> OrderedSet[T]: return self.intersection(other) - def __sub__(self, other): + def __sub__(self, other: OrderedSet[T]) -> OrderedSet[T]: return self.difference(other) - def __xor__(self, other): + def __xor__(self, other: OrderedSet[T]) -> OrderedSet[T]: return self.symmetric_difference(other) -K = TypeVar("K") -V = TypeVar("V") - - -class BiMap(MutableMapping[K, V]): +class BiMap[K, V](MutableMapping[K, V]): # Implements a bi-directional map for storing unique keys/values using two # dictionaries. # TODO: override __reversed__ for BiMap diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 3fddfc3d..13c30422 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -57,7 +57,7 @@ def evaluate_case( model = finalize_model(current_case) # Convert static keys to array if they are not scalar. for key, value in static_keys.items(): - if isinstance(model.conns._get_metadata(key).data, Scalar): + if isinstance(model.conns.get_metadata(key).data, Scalar): static_keys[key] = value else: static_keys[key] = convert_to_array(backend, value) @@ -189,16 +189,16 @@ def partial_fun(*args): def assert_models_equal(model1: BaseModel, model2: BaseModel): - model1_keys = model1._generate_keys() - model2_keys = model2._generate_keys() + model1_keys = model1.generate_keys() + model2_keys = model2.generate_keys() if model1.canonical_input is not None and model2.canonical_input is not None: assert model1_keys.get( - key := model1._canonical_input.key, key - ) == model2_keys.get(key := model2._canonical_input.key, key) + key := model1.canonical_input.key, key + ) == model2_keys.get(key := model2.canonical_input.key, key) assert model1_keys.get( - key := model1._canonical_output.key, key - ) == model2_keys.get(key := model2._canonical_output.key, key) + key := model1.canonical_output.key, key + ) == model2_keys.get(key := model2.canonical_output.key, key) # NOTE: Below assertions will be uncommented after converting # model's dag from topological order to insertion order. @@ -246,7 +246,7 @@ def assert_models_equal(model1: BaseModel, model2: BaseModel): == conn2[1].metadata.shape.get_shapes() ) - if conn1[1].key in model1._input_keys | model1.conns.output_keys: + if conn1[1].key in model1.input_keys | model1.conns.output_keys: assert model1_keys.get(key := conn1[1].key, key) == model2_keys.get( key := conn2[1].key, key ) diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index 3aee00d6..952b0d5e 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -2538,7 +2538,7 @@ def test_valued_conns_elevated_with_iokey(): ) # Note that string naming does not cause the connection # to be elevated as input to the upper level model. - assert model._input_keys == {"input", "start_dim"} + assert model.input_keys == {"input", "start_dim"} assert model.conns.latent_input_keys == {"end_dim"} @@ -2554,7 +2554,7 @@ def test_valued_conns_elevated_with_unexposed_iokey(): ) # Note that string naming does not cause the connection # to be elevated as input to the upper level model. - assert model._input_keys == {"input", "start_dim"} + assert model.input_keys == {"input", "start_dim"} assert model.conns.latent_input_keys == {"end_dim"} @@ -2562,7 +2562,7 @@ def test_scalar_conns_elevated_with_immediate_extend_value(): model = Model() flatten = Flatten(start_dim=TBD, end_dim=TBD) model += flatten(input="input", start_dim=0, end_dim=4, output=IOKey(name="output")) - assert len(model._input_keys) == 3 + assert len(model.input_keys) == 3 assert len(model.conns.latent_input_keys) == 0 diff --git a/tests/scripts/test_constr_counter.py b/tests/scripts/test_constr_counter.py index 43421af3..929a6140 100644 --- a/tests/scripts/test_constr_counter.py +++ b/tests/scripts/test_constr_counter.py @@ -60,8 +60,8 @@ def dummy_constraint(output: Tensor | Scalar, input: Tensor | Scalar): uniadics = [ Uniadic(val + add_val) if val is not None else Uniadic() for val in values ] - updates |= var_repr._update_uniadics(var_repr.prefix, uniadics) - updates |= var_repr._update_uniadics(var_repr.reverse, uniadics) + updates |= var_repr.update_uniadics(var_repr.prefix, uniadics) + updates |= var_repr.update_uniadics(var_repr.reverse, uniadics) updates |= var_repr.remove_variadic(uniadics) status = None not in values else: diff --git a/tests/scripts/test_constraints.py b/tests/scripts/test_constraints.py index bcd148fe..351e3031 100644 --- a/tests/scripts/test_constraints.py +++ b/tests/scripts/test_constraints.py @@ -240,12 +240,12 @@ def assert_type_results( assert updated_constraints == updated_symbols.constraints[UpdateType.TYPE] # Then check final types with the expected ref_results. for key, value in data.items(): - if isinstance(value._type, NestedListType): + if isinstance(value.type, NestedListType): result = ref_results[key] assert isinstance(result, NestedListType) - assert value._type.base_type == result.base_type + assert value.type.base_type == result.base_type else: - assert value._type == ref_results[key] + assert value.type == ref_results[key] def assert_value_results( @@ -363,7 +363,7 @@ def _assert_constraint_results( # If initial types are given, set them. if initial_types is not None: for key, type in initial_types.items(): - data[key]._type = type + data[key].type = type # If any initial values are given, set them. for key, value in initial_values.items(): diff --git a/tests/scripts/test_data_store.py b/tests/scripts/test_data_store.py index c0bec714..e6056dd2 100644 --- a/tests/scripts/test_data_store.py +++ b/tests/scripts/test_data_store.py @@ -60,7 +60,7 @@ def test_data_store_1(): pm.data_store.add_static_data(key, value) assert pm.data_store.data_values.keys() == {"input"} assert (pm.data_store.data_values[key].value == value).all() # type: ignore [union-attr] - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -94,7 +94,7 @@ def test_data_store_1_numpy(): "output_cache", } assert (pm.data_store.data_values[key].value == value).all() # type: ignore[union-attr] - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -111,7 +111,7 @@ def test_data_store_3(): assert pm.data_store.data_values.keys() == {"output_1"} assert (pm.data_store.data_values["output_1"] == backend.array(6.0)).all() # type: ignore[union-attr] - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "input", @@ -204,7 +204,7 @@ def test_data_store_7(): assert pm.data_store.data_values.keys() == {"input"} assert (res["output"] == value).all() # type: ignore[union-attr] - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -220,7 +220,7 @@ def test_data_store_8(): assert pm.data_store.data_values.keys() == {"output1"} assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"input"} @@ -237,7 +237,7 @@ def test_data_store_9(): assert pm.data_store.data_values.keys() == {"output1"} assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"input"} @@ -254,7 +254,7 @@ def test_data_store_10(): assert pm.data_store.data_values.keys() == {"input", "output2"} assert (pm.data_store.data_values["output2"] == backend.sigmoid(value)).all() # type: ignore[union-attr] - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -271,7 +271,7 @@ def test_data_store_11(): assert pm.data_store.data_values.keys() == {"output1", "output3"} assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert (pm.data_store.data_values["output3"] == backend.sigmoid(value) + 2).all() # type: ignore[union-attr] - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "right", @@ -296,7 +296,7 @@ def test_data_store_13(): ) assert pm.data_store.data_values.keys() == {"out"} - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"left", "right"} @@ -330,7 +330,7 @@ def test_data_store_14(): constant_keys={"input1": input1, "input2": input2, "weight": weight}, ) assert pm.data_store.data_values.keys() == {"input1", "out2"} - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { @@ -384,7 +384,7 @@ def test_data_store_15(): constant_keys={"input1": input1, "input2": input2, "weight": weight}, ) assert pm.data_store.data_values.keys() == {"input1", "out2"} - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { @@ -438,13 +438,13 @@ def test_data_store_16(): "output_1_cache", "output_cache", } - assert pm.data_store._runtime_static_keys == {"input"} + assert pm.data_store.runtime_static_keys == {"input"} assert pm.data_store._intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() def test_data_store_17(): - """Check '_runtime_static_keys'""" + """Check 'runtime_static_keys'""" backend = NumpyBackend(precision=32) model = Model() model += (add := Add())(left="left") @@ -466,13 +466,13 @@ def test_data_store_17(): ) assert pm.data_store.data_values.keys() == {"output_0_cache", "output_cache"} - assert pm.data_store._runtime_static_keys == {"right"} + assert pm.data_store.runtime_static_keys == {"right"} assert pm.data_store._intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() def test_data_store_18(): - """Test infer ignore should remove from Data store '_runtime_static_keys'""" + """Test infer ignore should remove from Data store 'runtime_static_keys'""" backend = TorchBackend(precision=32) model = Model() model += (add := Add())(left="left") @@ -497,7 +497,7 @@ def test_data_store_18(): ) assert pm.data_store.data_values.keys() == set() - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -528,7 +528,7 @@ def test_data_store_19(): ) assert pm.data_store.data_values.keys() == set() - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -559,6 +559,6 @@ def test_data_store_20(): ) assert pm.data_store.data_values.keys() == {"tensor_out"} - assert pm.data_store._runtime_static_keys == set() + assert pm.data_store.runtime_static_keys == set() assert pm.data_store._intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == {"left", "output_1"} diff --git a/tests/scripts/test_differentiablity.py b/tests/scripts/test_differentiablity.py index 3f00d56d..f8ff6132 100644 --- a/tests/scripts/test_differentiablity.py +++ b/tests/scripts/test_differentiablity.py @@ -28,14 +28,14 @@ def test_data_linear_compile(): model += Linear()(input="input") backend = JaxBackend() pm = mithril.compile(model, backend) - assert "input" in pm.data_store._runtime_static_keys + assert "input" in pm.data_store.runtime_static_keys def test_convert_input_data_to_trainable(): model = Model() model += Linear()(input="input") model += Linear()(weight=model.input) # type: ignore - assert model.input.data.metadata.data._differentiable # type: ignore + assert model.input.data.metadata.data.differentiable # type: ignore def test_convert_input_data_to_trainable_compile(): @@ -47,7 +47,7 @@ def test_convert_input_data_to_trainable_compile(): pm = mithril.compile(model, backend) assert ( "input" - not in pm.data_store._runtime_static_keys | pm.data_store.cached_data.keys() + not in pm.data_store.runtime_static_keys | pm.data_store.cached_data.keys() ) @@ -55,7 +55,7 @@ def test_convert_internal_data_to_trainable(): model = Model() model += Linear()(input="internal_key") model += Linear()(input="input", output=model.internal_key) # type: ignore - assert model.internal_key.data.metadata.data._differentiable # type: ignore + assert model.internal_key.data.metadata.data.differentiable # type: ignore def test_set_values_data_and_param(): @@ -80,7 +80,7 @@ def test_match_tensor_with_value_data_and_param(): model = Model() model += model1(left="my_input") model += model2(left="my_input") - assert model.my_input.data.metadata.data._differentiable # type: ignore + assert model.my_input.data.metadata.data.differentiable # type: ignore def test_match_tensor_with_value_data_and_param_rev(): @@ -109,7 +109,7 @@ def test_non_trainability_flow_in_compile(): backend = JaxBackend() pm = mithril.compile(model, backend) - assert not pm.data_store.all_data["output"]._differentiable + assert not pm.data_store.all_data["output"].differentiable def test_non_trainability_flow_in_compile_with_data_keys_1(): @@ -123,7 +123,7 @@ def test_non_trainability_flow_in_compile_with_data_keys_1(): pm = mithril.compile( model, backend, data_keys={"input"}, constant_keys={"left": backend.array(1.0)} ) - assert not pm.data_store.all_data["output"]._differentiable + assert not pm.data_store.all_data["output"].differentiable def test_non_trainability_flow_in_compile_with_data_keys_2(): @@ -135,7 +135,7 @@ def test_non_trainability_flow_in_compile_with_data_keys_2(): backend = JaxBackend() pm = mithril.compile(model, backend, data_keys={"input"}) - assert pm.data_store.all_data["output"]._differentiable + assert pm.data_store.all_data["output"].differentiable def test_non_trainability_flow_in_compile_with_data_keys_3(): @@ -148,8 +148,8 @@ def test_non_trainability_flow_in_compile_with_data_keys_3(): backend = JaxBackend() pm = mithril.compile(model, backend, data_keys={"input"}) - assert pm.data_store.all_data["mult_out"]._differentiable - assert pm.data_store.all_data["add_out"]._differentiable + assert pm.data_store.all_data["mult_out"].differentiable + assert pm.data_store.all_data["add_out"].differentiable def test_trainability_flow_in_compile_with_trainable_keys(): @@ -164,5 +164,5 @@ def test_trainability_flow_in_compile_with_trainable_keys(): backend = JaxBackend() pm = mithril.compile(model, backend, trainable_keys={"input"}) - assert pm.data_store.all_data["mult_out"]._differentiable - assert pm.data_store.all_data["add_out"]._differentiable + assert pm.data_store.all_data["mult_out"].differentiable + assert pm.data_store.all_data["add_out"].differentiable diff --git a/tests/scripts/test_extend_template.py b/tests/scripts/test_extend_template.py index 01b0979a..413e3e98 100644 --- a/tests/scripts/test_extend_template.py +++ b/tests/scripts/test_extend_template.py @@ -1514,7 +1514,17 @@ def test_tensoritem_multiple_slice_3(): outputs = pm.evaluate() out = outputs["output"] assert isinstance(out, jnp.ndarray) - assert out.shape == (1, 6) + assert ( + out.shape == (1, 6) + and out.shape == (1, 6) + and out.shape == (1, 6) + and out.shape == (1, 6) + and out.shape == (1, 6) + and out.shape == (1, 6) + and out.shape == (1, 6) + and out.shape == (1, 6) + and out.shape == (1, 6) + ) def test_tensor_item_with_ellipsis_at_beginning(): @@ -1671,7 +1681,7 @@ def test_immediate_values_with_extend_template_and_regular_case(): big_model_2 = Model() big_model_2 += model(input="input", output="output") - assert big_model_1._input_keys == big_model_2._input_keys == {"input"} + assert big_model_1.input_keys == big_model_2.input_keys == {"input"} assert ( big_model_1.conns.latent_input_keys == big_model_1.conns.latent_input_keys diff --git a/tests/scripts/test_functions.py b/tests/scripts/test_functions.py index 55a31abb..53b139b7 100644 --- a/tests/scripts/test_functions.py +++ b/tests/scripts/test_functions.py @@ -198,7 +198,7 @@ def test_flatten_dag_1(): comp_model = mithril.compile(model=model4, backend=JaxBackend(precision=64)) flatted_primitive_model_list = [ - key.__class__ for key in comp_model._flat_graph.get_models() + key.__class__ for key in comp_model.flat_graph.get_models() ] assert flatted_primitive_model_list == [ @@ -258,7 +258,7 @@ def test_flatten_dag_2(): comp_model = mithril.compile(model=model4, backend=JaxBackend(precision=64)) flatted_primitive_model_list = [ - key.__class__ for key in comp_model._flat_graph.get_models() + key.__class__ for key in comp_model.flat_graph.get_models() ] assert flatted_primitive_model_list == [ @@ -301,7 +301,7 @@ def test_flatten_dag_3(): comp_model = mithril.compile(model=model1, backend=JaxBackend(precision=64)) flatted_primitive_model_list = [ - key.__class__ for key in comp_model._flat_graph.get_models() + key.__class__ for key in comp_model.flat_graph.get_models() ] assert flatted_primitive_model_list == [ diff --git a/tests/scripts/test_inference.py b/tests/scripts/test_inference.py index 1ad0b601..332bf5fa 100644 --- a/tests/scripts/test_inference.py +++ b/tests/scripts/test_inference.py @@ -51,7 +51,7 @@ def test_discard_keys_inference(case: str) -> None: model = finalize_model(current_case) if isinstance(model, TrainModel): - model._finalize() + model.finalize() reference_output_keys = sorted(results.get("output_keys", {})) reference_discard_keys = sorted(results.get("discard_keys", {})) @@ -73,7 +73,7 @@ def test_discard_keys_inference(case: str) -> None: discarded_keys = pm.discarded_keys output_keys = pm.output_keys - hanging_keys = pm._flat_graph.hanging_keys + hanging_keys = pm.flat_graph.hanging_keys discard_keys |= {key for key in hanging_keys if key not in pm.output_keys} assert sorted(discarded_keys) == reference_discard_keys diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index a3425b78..471fdb4e 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -66,7 +66,7 @@ def assert_model_keys( pm = mithril.compile(model=model, backend=TorchBackend(), safe_names=False) - physical_inputs = set(pm._input_keys) + physical_inputs = set(pm.input_keys) assert physical_inputs == physical_inputs_ref, "physical inputs does not match." physical_outputs = set(pm.output_keys) @@ -612,7 +612,7 @@ def test_iokey_values_10(): model = Model() sig_model_1 = Sigmoid() sig_model_2 = Sigmoid() - sig_model_1.input.data.metadata.data._type = float + sig_model_1.input.data.metadata.data.type = float model += sig_model_1(input="input", output=IOKey(name="output")) model += sig_model_2( @@ -642,7 +642,7 @@ def test_iokey_values_11(): input=IOKey(type=float, name="input"), output=IOKey(name="output2") ) - assert sig_model_1.input.data.metadata.data._type is float + assert sig_model_1.input.data.metadata.data.type is float def test_iokey_values_12(): @@ -1195,7 +1195,7 @@ def test_iokey_template_1(): ) expected_result = np.array([8.0]) - assert pm._input_keys == {"left", "right"} + assert pm.input_keys == {"left", "right"} assert pm.output_keys == ["output"] np.testing.assert_array_equal(res["output"], expected_result) @@ -1217,7 +1217,7 @@ def test_iokey_template_2(): ) expected_result = np.array([5.0]) - assert pm._input_keys == {"left", "right"} + assert pm.input_keys == {"left", "right"} assert pm.output_keys == ["output"] np.testing.assert_array_equal(res["output"], expected_result) @@ -1236,7 +1236,7 @@ def test_iokey_template_3(): res = pm.evaluate(params={"left": backend.array([2.0])}) expected_result = np.array([5.0]) - assert pm._input_keys == {"left", "input"} + assert pm.input_keys == {"left", "input"} assert pm.output_keys == ["output"] np.testing.assert_array_equal(res["output"], expected_result) @@ -1255,7 +1255,7 @@ def test_iokey_template_4(): res = pm.evaluate(params={"left": backend.ones((9, 8, 7))}) expected_result = 9 - assert pm._input_keys == {"left", "index"} + assert pm.input_keys == {"left", "index"} assert pm.output_keys == ["output"] np.testing.assert_array_equal(res["output"], expected_result) @@ -1274,7 +1274,7 @@ def test_iokey_template_5(): res = pm.evaluate(data={"left": [1, 2, 3]}) expected_result = np.array([1, 2, 3]) - assert pm._input_keys == {"left"} + assert pm.input_keys == {"left"} assert pm.output_keys == ["output"] np.testing.assert_array_equal(res["output"], expected_result) diff --git a/tests/scripts/test_key_namings.py b/tests/scripts/test_key_namings.py index 2b0841ad..54109cb6 100644 --- a/tests/scripts/test_key_namings.py +++ b/tests/scripts/test_key_namings.py @@ -29,9 +29,9 @@ def assert_keys(model, logical_ref, physical_ref, include_internals=False): - assert logical_ref == model._generate_keys(include_internals=include_internals) + assert logical_ref == model.generate_keys(include_internals=include_internals) pm = mithril.compile(model=model, backend=TorchBackend(), safe_names=False) - assert set(pm._input_keys) == set(physical_ref) + assert set(pm.input_keys) == set(physical_ref) def test_finalize_keys_0(): @@ -45,7 +45,7 @@ def test_finalize_keys_0(): input=model.canonical_output, bias="bias_3", output=IOKey(name="output1") ) pm = mithril.compile(model, TorchBackend(), safe_names=False) - assert set(pm._input_keys) == set( + assert set(pm.input_keys) == set( ( "weight_5", "bias_3", @@ -72,7 +72,7 @@ def test_finalize_keys_1(): model += Linear(10) model += Linear(10) pm = mithril.compile(model, TorchBackend(), safe_names=False) - assert set(pm._input_keys) == set( + assert set(pm.input_keys) == set( ( "input", "weight_0", @@ -89,7 +89,7 @@ def test_finalize_keys_1(): ) model += Linear(10) pm = mithril.compile(model, TorchBackend(), safe_names=False) - assert set(pm._input_keys) == set( + assert set(pm.input_keys) == set( ( "weight_3", "weight_1", @@ -112,7 +112,7 @@ def test_finalize_keys_1(): model += Linear(10) model += Linear(10) pm = mithril.compile(model, TorchBackend(), safe_names=False) - assert set(pm._input_keys) == set( + assert set(pm.input_keys) == set( ("bias_1", "weight_2", "bias_0", "input", "bias_2", "weight_1", "weight_0") ) @@ -121,10 +121,10 @@ def test_finalize_keys_2(): model = Model() model += Linear(10) pm = mithril.compile(model, TorchBackend(), safe_names=False) - assert set(pm._input_keys) == set(("input", "weight", "bias")) + assert set(pm.input_keys) == set(("input", "weight", "bias")) model += Linear(10) pm = mithril.compile(model, TorchBackend(), safe_names=False) - assert set(pm._input_keys) == set( + assert set(pm.input_keys) == set( ("input", "weight_0", "bias_0", "weight_1", "bias_1") ) @@ -134,7 +134,7 @@ def test_generate_input_keys_0(): model = Model() model += (lin1 := Linear(10)) model += (lin2 := Linear(10)) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$weight_0", "$3": "$input", @@ -144,7 +144,7 @@ def test_generate_input_keys_0(): } model += (lin3 := Linear(10))(input=lin1.output) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$weight_0", "$3": "$input", @@ -156,7 +156,7 @@ def test_generate_input_keys_0(): } model += Add()(left=lin2.output, right=lin3.output) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$weight_0", "$3": "$input", @@ -169,7 +169,7 @@ def test_generate_input_keys_0(): # Extend from input model += Linear(10)(input="", output=lin1.input) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$weight_1", "$4": "$bias_1", @@ -185,15 +185,15 @@ def test_generate_input_keys_0(): def test_generate_input_keys_1(): model = Model() - # key_mappings = model._generate_keys(include_internals = False) + # key_mappings = model.generate_keys(include_internals = False) # assert key_mappings == {} model += Linear(10) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == {"$1": "$weight", "$3": "$input", "$4": "$bias"} model += Linear(10) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$weight_0", "$3": "$input", @@ -203,7 +203,7 @@ def test_generate_input_keys_1(): } model += Linear(10) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$weight_0", "$3": "$input", @@ -215,7 +215,7 @@ def test_generate_input_keys_1(): } model += Linear(10)(input="", output=model.canonical_input) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$14": "$weight_0", "$16": "$input", @@ -231,15 +231,15 @@ def test_generate_input_keys_1(): def test_generate_input_keys_2(): model = Model() - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == {} model += Linear(10) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == {"$1": "$weight", "$3": "$input", "$4": "$bias"} model += Linear(10) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$weight_0", "$3": "$input", @@ -249,7 +249,7 @@ def test_generate_input_keys_2(): } model += Linear(10)(input="input", weight="weight_0") - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$_weight_0", "$3": "$_input", @@ -276,7 +276,7 @@ def test_generate_input_keys_3(): model_0 += model1 model_0 += model2 model_0 += model3 - key_mappings = model_0._generate_keys(include_internals=False) + key_mappings = model_0.generate_keys(include_internals=False) assert key_mappings == { "$1": "$input", "$2": "$in_right_0", @@ -305,7 +305,7 @@ def test_generate_input_keys_4(): model_0 += Linear(10)( input=model_0.canonical_output, weight="in_left_1", bias="in_right_1" ) - key_mappings = model_0._generate_keys(include_internals=False) + key_mappings = model_0.generate_keys(include_internals=False) assert key_mappings == { "$1": "$input", "$2": "$_in_right_0", @@ -320,18 +320,18 @@ def test_generate_input_keys_5(): for _ in range(5): model += Sigmoid() model += Linear()(input=model.canonical_output, weight="input") - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == {"$1": "$_input", "$8": "$bias"} def test_generate_input_keys_6(): model = Model() model += Linear() - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == {"$1": "$weight", "$3": "$input", "$4": "$bias"} model += Linear()(input="", output=model.canonical_input) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$6": "$weight_0", "$8": "$input", @@ -341,7 +341,7 @@ def test_generate_input_keys_6(): } model += Linear()(input="", output=model.canonical_input) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$10": "$weight_0", "$12": "$input", @@ -353,7 +353,7 @@ def test_generate_input_keys_6(): } model += Linear()(input="", output=model.canonical_input) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$14": "$weight_0", "$16": "$input", @@ -367,7 +367,7 @@ def test_generate_input_keys_6(): } model += Linear()(input="", output=model.canonical_input) - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$18": "$weight_0", "$20": "$input", @@ -391,7 +391,7 @@ def test_generate_input_keys_7(): model += con_1 model += con_2 model += con_3 - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$input", "$2": "$input2_0", @@ -414,7 +414,7 @@ def test_generate_input_keys_8(): model += con_1 model += con_2 model += con_3 - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$input", "$2": "$input2_0", @@ -427,7 +427,7 @@ def test_generate_input_keys_8(): "$13": "$input4_1", "$14": "$input5", } - key_mappings = model._generate_keys(include_internals=True) + key_mappings = model.generate_keys(include_internals=True) assert key_mappings == { "$1": "$input", "$2": "$input2_0", @@ -456,7 +456,7 @@ def test_generate_input_keys_9(): model_1 += con_2 model_2 = deepcopy(model_1) model_1 += model_2 - key_mappings = model_1._generate_keys(include_internals=False) + key_mappings = model_1.generate_keys(include_internals=False) assert key_mappings == { "$1": "$input", "$4": "$__input2_0", @@ -481,7 +481,7 @@ def test_generate_input_keys_10(): model += model2 model += model3 model += model4 - key_mappings = model._generate_keys(include_internals=False) + key_mappings = model.generate_keys(include_internals=False) assert key_mappings == { "$1": "$input", "$2": "$input2_0", diff --git a/tests/scripts/test_key_values_in_init.py b/tests/scripts/test_key_values_in_init.py index a502fe1a..81e5e416 100644 --- a/tests/scripts/test_key_values_in_init.py +++ b/tests/scripts/test_key_values_in_init.py @@ -25,7 +25,7 @@ def test_directed_call_connection(): connection = add2.output # Assume this is a Connection object info = add1(left=connection, right="right") - left_info = info._connections["left"] + left_info = info.connections["left"] assert isinstance(left_info, ml.IOKey) assert left_info._connections == OrderedSet([connection.data]) @@ -37,7 +37,7 @@ def test_directed_call_int(): add1 = Add(left=1) info = add1(left=1, right="right") - assert info._connections["left"] == 1 + assert info.connections["left"] == 1 def test_directed_call_int_error(): @@ -56,11 +56,11 @@ def test_directed_call_str(): add1 = Add(left=1) info = add1(left="in1", right="right") - left_info = info._connections["left"] + left_info = info.connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 + assert left_info.name == "in1" + assert left_info.value == 1 def test_directed_call_iokey_value_equal(): @@ -68,11 +68,11 @@ def test_directed_call_iokey_value_equal(): iokey = ml.IOKey("in1", value=1) # value matches val in factory_inputs info = add1(left=iokey, right="right") - left_info = info._connections["left"] + left_info = info.connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 + assert left_info.name == "in1" + assert left_info.value == 1 def test_directed_call_iokey_value_not_equal(): @@ -89,11 +89,11 @@ def test_directed_call_iokey_value_tbd(): iokey = ml.IOKey("in1") # value is TBD info = add1(left=iokey, right="right") - left_info = info._connections["left"] + left_info = info.connections["left"] assert isinstance(left_info, ml.IOKey) - assert left_info._name == "in1" - assert left_info._value == 1 # value is set to val from factory_inputs + assert left_info.name == "in1" + assert left_info.value == 1 # value is set to val from factory_inputs def test_directed_call_connect_key_value_not_equal(): @@ -111,7 +111,7 @@ def test_directed_call_connect_key_none(): con = ml.IOKey(connections=[connection]) info = add1(left=con, right="right") - left_info = info._connections["left"] + left_info = info.connections["left"] assert isinstance(left_info, ml.IOKey) assert left_info._connections == OrderedSet([connection.data]) assert left_info._value == 1 # key is set to IOKey with val from factory_inputs @@ -124,7 +124,7 @@ def test_directed_call_connect_key_value_tbd(): info = add1(left=con, right="right") - left_info = info._connections["left"] + left_info = info.connections["left"] assert isinstance(left_info, ml.IOKey) assert left_info._connections == OrderedSet([connection.data]) assert isinstance(left_info, ml.IOKey) @@ -138,7 +138,7 @@ def test_directed_call_connect_key_value_equal(): info = add1(left=con, right="right") - left_info = info._connections["left"] + left_info = info.connections["left"] assert isinstance(left_info, ml.IOKey) assert left_info._connections == OrderedSet([connection.data]) assert left_info._value == 1 # value is set to val from factory_inputs @@ -161,7 +161,7 @@ def test_directed_call_key_not_in_kwargs(): info = add1() # No kwargs provided - assert info._connections["left"] == 1 + assert info.connections["left"] == 1 def test_directed_call_factory_val_tbd(): @@ -169,8 +169,8 @@ def test_directed_call_factory_val_tbd(): info = add1(left="in1", right="in2") - assert info._connections["left"] == "in1" - assert info._connections["right"] == "in2" + assert info.connections["left"] == "in1" + assert info.connections["right"] == "in2" def test_integration_call_arg_connection(): diff --git a/tests/scripts/test_primitive_calls.py b/tests/scripts/test_primitive_calls.py index 5281eb8f..62893857 100644 --- a/tests/scripts/test_primitive_calls.py +++ b/tests/scripts/test_primitive_calls.py @@ -92,4 +92,4 @@ def test_error_robust_power_call_threshold_input_keys(): model2 += pow2(threshold="thres") model2.set_values({"thres": 0.1}) - assert model1._input_keys == model2._input_keys + assert model1.input_keys == model2.input_keys diff --git a/tests/scripts/test_randomized_models_all_backends.py b/tests/scripts/test_randomized_models_all_backends.py index 7fff2b74..4cb4699a 100644 --- a/tests/scripts/test_randomized_models_all_backends.py +++ b/tests/scripts/test_randomized_models_all_backends.py @@ -177,7 +177,7 @@ def test_randomized(case: str) -> None: ) static_inputs[init_key] = { key: init_backend.array(value) - if isinstance(model.conns._get_metadata(key).data, Tensor) + if isinstance(model.conns.get_metadata(key).data, Tensor) else value for key, value in static_inputs[init_key].items() } @@ -226,7 +226,7 @@ def test_randomized(case: str) -> None: } static_inputs[backend.backend_type] = { key: backend.array(value) - if isinstance(model.conns._get_metadata(key).data, Tensor) + if isinstance(model.conns.get_metadata(key).data, Tensor) else value for key, value in static_inputs[init_key].items() } diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 228e5314..799eecf9 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -264,7 +264,7 @@ def test_cyclic_extension_5(): ), ) - assert set(model._input_keys) == {"input2", "input3", "input5", "input6"} + assert set(model.input_keys) == {"input2", "input3", "input5", "input6"} assert model.conns.internal_keys == {"my_input"} @@ -815,7 +815,7 @@ def test_canonical_output_exposed_2(): model1 += Linear(dimension=16)(input=model1.canonical_output, output="output1") extend_info = model1(output1="output1") - assert extend_info._connections == {"output1": "output1"} + assert extend_info.connections == {"output1": "output1"} def test_canonical_output_exposed_3(): @@ -1352,7 +1352,7 @@ def test_static_key_names_consistency(): model += Add()(left=3) pm = mithril.compile(model, TorchBackend()) - assert {"left", "right"} == pm._input_keys + assert {"left", "right"} == pm.input_keys def test_evaluate_replace(): @@ -1366,7 +1366,7 @@ def test_evaluate_replace(): jit=False, ) - assert set(comp_model._input_keys) == {"in", "for", "add"} + assert set(comp_model.input_keys) == {"in", "for", "add"} def test_evaluate_replace_2(): @@ -1390,7 +1390,7 @@ def test_evaluate_replace_2(): backend=NumpyBackend(), jit=False, ) - assert set(comp_model._input_keys) == { + assert set(comp_model.input_keys) == { "in", "for", "add", @@ -1542,7 +1542,7 @@ def test_canonic_example(): model += LeakyRelu() model += LeakyRelu() comp_model = compile(model=model, backend=NumpyBackend()) - assert set(comp_model._input_keys) == {"slope_0", "slope_1", "input"} + assert set(comp_model.input_keys) == {"slope_0", "slope_1", "input"} assert set(comp_model.output_keys) == {"output"} inputs = {"input": np.array([[2.0, -1.0]])} assert_results_equal( @@ -2479,9 +2479,9 @@ def test_static_anlaysis(): ignored_model_keys = ( comp_model.data_store.cached_data.keys() | comp_model.discarded_keys ) - ignored_output_keys = ignored_model_keys & comp_model._flat_graph.all_target_keys + ignored_output_keys = ignored_model_keys & comp_model.flat_graph.all_target_keys ignored_model_list = [ - comp_model._flat_graph.get_model(key) for key in ignored_output_keys + comp_model.flat_graph.get_model(key) for key in ignored_output_keys ] assert ignored_model_list == [add1] @@ -2501,11 +2501,9 @@ def test_static_anlaysis_1(): discarded_model_keys = ( comp_model.data_store.cached_data.keys() | comp_model.discarded_keys ) - discarded_output_keys = ( - discarded_model_keys & comp_model._flat_graph.all_target_keys - ) + discarded_output_keys = discarded_model_keys & comp_model.flat_graph.all_target_keys discarded_model_list = [ - comp_model._flat_graph.get_model(key) for key in discarded_output_keys + comp_model.flat_graph.get_model(key) for key in discarded_output_keys ] assert discarded_model_list == [add1] @@ -2529,11 +2527,9 @@ def test_static_anlaysis_2(): | comp_model.data_store.unused_keys | comp_model.discarded_keys ) - discarded_output_keys = ( - discarded_model_keys & comp_model._flat_graph.all_target_keys - ) + discarded_output_keys = discarded_model_keys & comp_model.flat_graph.all_target_keys discarded_model_list = { - comp_model._flat_graph.get_model(key) for key in discarded_output_keys + comp_model.flat_graph.get_model(key) for key in discarded_output_keys } assert len(discarded_model_list) == 2 @@ -2553,7 +2549,7 @@ def test_static_anlaysis_4(): comp_model = mithril.compile(model=model, backend=NumpyBackend()) models = {add1, add2, sum1, sub1, mul1, mat1} - assert (models - comp_model._flat_graph.nodes.keys()) == {mat1} + assert (models - comp_model.flat_graph.nodes.keys()) == {mat1} def test_prune_1(): @@ -2585,7 +2581,7 @@ def test_prune_1(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_2(): @@ -2617,7 +2613,7 @@ def test_prune_2(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_3(): @@ -2653,7 +2649,7 @@ def test_prune_3(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_4(): @@ -2685,7 +2681,7 @@ def test_prune_4(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_5(): @@ -2716,7 +2712,7 @@ def test_prune_5(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_6(): @@ -2750,7 +2746,7 @@ def test_prune_6(): expected_output_dict = {"acc": "acc", "auc": "auc"} assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_7(): @@ -2780,7 +2776,7 @@ def test_prune_7(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_8(): @@ -2809,7 +2805,7 @@ def test_prune_8(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_9(): @@ -2836,7 +2832,7 @@ def test_prune_9(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_10(): @@ -2871,7 +2867,7 @@ def test_prune_10(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_11(): @@ -2910,7 +2906,7 @@ def test_prune_11(): } assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_12(): @@ -2927,7 +2923,7 @@ def test_prune_12(): expected_output_dict = {"out_3": "out_1", "out_2": "out_1", "out_1": "out_1"} assert_connections(compiled_model, expected_connections) - assert expected_output_dict == compiled_model._flat_graph.output_dict + assert expected_output_dict == compiled_model.flat_graph.output_dict def test_prune_13(): @@ -2944,7 +2940,7 @@ def test_prune_13(): expected_output_dict = {"out_3": "out_1", "out_1": "out_1"} assert_connections(compiled_model, expected_connections) - assert expected_output_dict == compiled_model._flat_graph.output_dict + assert expected_output_dict == compiled_model.flat_graph.output_dict def test_prune_14(): @@ -2961,7 +2957,7 @@ def test_prune_14(): expected_output_dict = {"out_3": "out_1", "out_2": "out_1", "out_1": "out_1"} assert_connections(compiled_model, expected_connections) - assert expected_output_dict == compiled_model._flat_graph.output_dict + assert expected_output_dict == compiled_model.flat_graph.output_dict def test_prune_15(): @@ -2979,7 +2975,7 @@ def test_prune_15(): expected_output_dict = {"out_3": "out_3", "out_1": "out_1"} assert_connections(compiled_model, expected_connections) - assert expected_output_dict == compiled_model._flat_graph.output_dict + assert expected_output_dict == compiled_model.flat_graph.output_dict def test_prune_valued_tensor_1(): @@ -3019,7 +3015,7 @@ def test_prune_valued_tensor_2(): expected_output_dict = {"output2": "output1", "output1": "output1"} assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_valued_tensor_3(): @@ -3043,7 +3039,7 @@ def test_prune_valued_tensor_3(): expected_output_dict = {"output2": "output1", "output1": "output1"} assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_valued_tensor_4(): @@ -3069,7 +3065,7 @@ def test_prune_valued_tensor_4(): expected_output_dict = {"output2": "output2", "output1": "output1"} assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_valued_tensor_5(): @@ -3101,7 +3097,7 @@ def test_prune_valued_tensor_5(): expected_output_dict = {"out1": "out2", "out2": "out2"} assert_connections(compiled_model, expected_connections) - assert compiled_model._flat_graph.output_dict == expected_output_dict + assert compiled_model.flat_graph.output_dict == expected_output_dict def test_prune_duplicate_grad(): @@ -3364,7 +3360,7 @@ def test_replace_with_primitive_2(): assert expected_key_mapping == list(dag.values())[0] # assert {} == comp_model.non_differentiables assert set() == comp_model.data_store.all_static_keys - assert set(["query", "key", "mask", "value"]) == set(comp_model._input_keys) + assert set(["query", "key", "mask", "value"]) == set(comp_model.input_keys) assert set(["output"]) == set(comp_model.output_keys) @@ -3398,7 +3394,7 @@ def test_replace_with_primitive_3(): # assert {"q", "k", "v", "m", "output"} == comp_model.data_store.all_static_keys assert {"output"} == comp_model.data_store.all_static_keys assert {"q", "k", "v", "m"} == comp_model.data_store.unused_keys - assert set(["q", "k", "m", "v"]) == set(comp_model._input_keys) + assert set(["q", "k", "m", "v"]) == set(comp_model.input_keys) assert set(["output"]) == set(comp_model.output_keys) @@ -3429,7 +3425,7 @@ def test_replace_with_primitive_4(): assert expected_key_mapping == list(dag.values())[0] # assert {} == comp_model.non_differentiables assert {"q", "k", "m"} == comp_model.data_store.all_static_keys - assert set(["q", "k", "m", "v"]) == set(comp_model._input_keys) + assert set(["q", "k", "m", "v"]) == set(comp_model.input_keys) assert set(["output"]) == set(comp_model.output_keys) @@ -3578,7 +3574,7 @@ def create_layer( model += Linear(1) pm = mithril.compile(model=model, backend=TorchBackend(), safe_names=False) - assert set(pm._input_keys) == { + assert set(pm.input_keys) == { "weight_0", "bias_1", "bias_3", @@ -3661,9 +3657,9 @@ def test_flatgraph_4(): ) pm = mithril.compile(model=model, backend=backend) - assert pm._input_keys == {"relu_2"} - assert len(pm._flat_graph.all_source_keys) == 3 - assert len(pm._flat_graph.all_target_keys) == 3 + assert pm.input_keys == {"relu_2"} + assert len(pm.flat_graph.all_source_keys) == 3 + assert len(pm.flat_graph.all_target_keys) == 3 def test_empy_out_grad(): @@ -4084,7 +4080,7 @@ def test_dict_to_model_using_connect(): model = dict_to_model(json_model) - assert model._input_keys == {"right", "left"} + assert model.input_keys == {"right", "left"} def test_connect_composite_2_extend_from_inputs(): @@ -4276,7 +4272,7 @@ def test_connect_11(): add = Add() model += add(left=IOKey(value=TBD, name="a"), right="right") - assert model._input_keys == {"a", "right"} + assert model.input_keys == {"a", "right"} assert ( model.dag[add]["left"].key == "a" ) # Checks "a" is assigned to the right connection. @@ -4296,7 +4292,7 @@ def test_connect_12(): output=IOKey(name="out3"), ) - assert model._input_keys == {"left", "l2", "l4", "right"} + assert model.input_keys == {"left", "l2", "l4", "right"} assert ( model.dag[add3]["left"].key == "left" ) # Checks "left" is assigned to the right connection. @@ -4312,7 +4308,7 @@ def test_connect_13(): model += buf(input=IOKey(name="input", connections=[add1.left, add2.left])) model += Add()(left=add2.output, right=buf.output, output=IOKey(name="out2")) - assert model._input_keys == {"input", "l2", "l4"} + assert model.input_keys == {"input", "l2", "l4"} def test_connect_14(): @@ -4321,7 +4317,7 @@ def test_connect_14(): model += Add()(left="l3", right="l4", output=IOKey(name="out2")) model += ToTensor()(input=IOKey(value=5, name="input"), output=IOKey(name="out3")) - assert model._input_keys == {"input", "l1", "l2", "l3", "l4"} + assert model.input_keys == {"input", "l1", "l2", "l3", "l4"} def test_connect_error_1(): @@ -5099,10 +5095,8 @@ def test_dependency_map_latent_to_input(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map # Add second model with global output. model += (buff := Buffer())(output=IOKey("buff_out")) @@ -5129,10 +5123,8 @@ def test_dependency_map_latent_to_input(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map # Add third model which changes name of a latent input and # makes it a real input of the model. @@ -5166,10 +5158,8 @@ def test_dependency_map_latent_to_input(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map def test_dependency_map_1(): @@ -5196,10 +5186,8 @@ def test_dependency_map_1(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5237,10 +5225,8 @@ def test_dependency_map_1_set_outputs(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5300,10 +5286,8 @@ def test_dependency_map_2(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5365,10 +5349,8 @@ def test_dependency_map_2_set_outputs(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5420,10 +5402,8 @@ def test_dependency_map_3(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5476,10 +5456,8 @@ def test_dependency_map_3_set_outputs(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5531,10 +5509,8 @@ def test_dependency_map_4(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5587,10 +5563,8 @@ def test_dependency_map_4_set_outputs_1(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5644,10 +5618,8 @@ def test_dependency_map_4_set_outputs_2(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5711,10 +5683,8 @@ def test_dependency_map_5(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5779,10 +5749,8 @@ def test_dependency_map_5_set_outputs_1(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5847,10 +5815,8 @@ def test_dependency_map_5_set_outputs_2(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5914,10 +5880,8 @@ def test_dependency_map_6(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -5984,10 +5948,8 @@ def test_dependency_map_6_set_outputs_1(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -6052,10 +6014,8 @@ def test_dependency_map_6_set_outputs_2(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -6121,10 +6081,8 @@ def test_dependency_map_7(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -6191,10 +6149,8 @@ def test_dependency_map_7_set_outputs_1(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -6261,10 +6217,8 @@ def test_dependency_map_7_set_outputs_2(): expected_global_output_map == model.dependency_map._global_output_dependency_map ) - assert expected_local_input_map == model.dependency_map._local_input_dependency_map - assert ( - expected_local_output_map == model.dependency_map._local_output_dependency_map - ) + assert expected_local_input_map == model.dependency_map.local_input_dependency_map + assert expected_local_output_map == model.dependency_map.local_output_dependency_map assert ( expected_global_input_map_cache @@ -6544,7 +6498,7 @@ def test_discard_trainables_1(): shapes={"input": [1, 2], "sidein": [2, 3]}, ) - assert {"input"} == pm._input_keys + assert {"input"} == pm.input_keys assert {"sidein", "sideout"} == pm.discarded_keys assert pm.get_shapes(model) == { "input": [1, 2], @@ -6563,7 +6517,7 @@ def test_discard_trainables_2(): pm = compile(model, backend, shapes={"sidein": [1, 2]}) - assert {"input"} == pm._input_keys + assert {"input"} == pm.input_keys assert {"sidein", "output_0"} == pm.discarded_keys assert pm.get_shapes(model) == { "$_Sigmoid_1_output": [1, 2], @@ -6583,7 +6537,7 @@ def test_discard_trainables_3(): pm = compile(model, backend, shapes={"sidein": [1, 2]}) - assert {"input"} == pm._input_keys + assert {"input"} == pm.input_keys assert {"sidein", "output_0"} == pm.discarded_keys assert pm.get_shapes(model) == { "$_Sigmoid_1_output": [1, 2], @@ -6612,7 +6566,7 @@ def test_discard_trainables_4(): shapes={"sideout": [1, 2, 3]}, ) - assert {"input"} == pm._input_keys + assert {"input"} == pm.input_keys assert {"sidein", "output_0", "sideout"} == pm.discarded_keys assert pm.get_shapes(model) == { "$_Sigmoid_1_output": [1, 2, 3], @@ -7302,7 +7256,7 @@ def test_generate_keys_duplicates(): model2 = Model() model2 += model() - key_mappings = model2._generate_keys(include_internals=True) + key_mappings = model2.generate_keys(include_internals=True) expected_key_mappings = { "$1": "$left", "$2": "$right", diff --git a/tests/scripts/test_set_outputs.py b/tests/scripts/test_set_outputs.py index e6e0867f..b174cc2d 100644 --- a/tests/scripts/test_set_outputs.py +++ b/tests/scripts/test_set_outputs.py @@ -61,7 +61,7 @@ def test_1(): data = {"input": backend.array([[1.0, 2], [3, 4]])} # Check equality. - assert model_1._input_keys == model_2._input_keys + assert model_1.input_keys == model_2.input_keys assert model_1.output_keys == model_2.output_keys assert model_1.conns.internal_keys == model_2.conns.internal_keys compare_models(model_1, model_2, backend, data) @@ -96,7 +96,7 @@ def test_2(): data = {"input": backend.array([[1.0, 2], [3, 4]])} # Check equality. - assert model_1._input_keys == model_2._input_keys + assert model_1.input_keys == model_2.input_keys assert model_1.output_keys == model_2.output_keys assert model_1.conns.internal_keys == model_2.conns.internal_keys compare_models(model_1, model_2, backend, data) @@ -133,7 +133,7 @@ def test_3(): data = {"input": backend.array([[1.0, 2], [3, 4]])} # Check equality. - assert model_1._input_keys == model_2._input_keys + assert model_1.input_keys == model_2.input_keys assert set(model_1.output_keys) == set(model_2.output_keys) assert model_1.conns.internal_keys == model_2.conns.internal_keys compare_models(model_1, model_2, backend, data) @@ -168,7 +168,7 @@ def test_4(): data = {"input": backend.array([[1.0, 2], [3, 4]])} # Check equality. - assert model_1._input_keys == model_2._input_keys + assert model_1.input_keys == model_2.input_keys assert set(model_1.output_keys) == set(model_2.output_keys) assert model_1.conns.internal_keys == model_2.conns.internal_keys compare_models(model_1, model_2, backend, data) diff --git a/tests/scripts/test_set_types.py b/tests/scripts/test_set_types.py index 9fdcb46f..13a7821d 100644 --- a/tests/scripts/test_set_types.py +++ b/tests/scripts/test_set_types.py @@ -24,7 +24,7 @@ def test_set_types_1(): model += sig_model(input="input", output=IOKey("output")) model.set_types({"input": int}) input_data = sig_model.input.metadata.data - assert input_data._type is int + assert input_data.type is int def test_set_types_1_kwargs_arg(): @@ -33,7 +33,7 @@ def test_set_types_1_kwargs_arg(): model += sig_model(input="input", output=IOKey("output")) model.set_types(input=int) input_data = sig_model.input.metadata.data - assert input_data._type is int + assert input_data.type is int def test_set_types_2(): @@ -42,7 +42,7 @@ def test_set_types_2(): model += buffer_model(input="input", output=IOKey(name="output")) model.set_types({"input": int | bool}) input_data = buffer_model.input.metadata.data - assert input_data._type == int | bool + assert input_data.type == int | bool def test_set_types_2_kwargs_arg(): @@ -51,7 +51,7 @@ def test_set_types_2_kwargs_arg(): model += buffer_model(input="input", output=IOKey(name="output")) model.set_types(input=int | bool) input_data = buffer_model.input.metadata.data - assert input_data._type == int | bool + assert input_data.type == int | bool def test_set_types_3(): @@ -60,7 +60,7 @@ def test_set_types_3(): model += buffer_model(input="input", output=IOKey(name="output")) model.set_types({buffer_model.input: int | bool}) input_data = buffer_model.input.metadata.data - assert input_data._type == int | bool + assert input_data.type == int | bool def test_set_types_3_kwargs_arg_1(): @@ -69,7 +69,7 @@ def test_set_types_3_kwargs_arg_1(): model += buffer_model(input="input", output=IOKey(name="output")) model.set_types(input=int | bool) input_data = buffer_model.input.metadata.data - assert input_data._type == int | bool + assert input_data.type == int | bool def test_set_types_3_kwargs_arg_2(): @@ -78,7 +78,7 @@ def test_set_types_3_kwargs_arg_2(): model += buffer_model(input="input", output=IOKey(name="output")) buffer_model.set_types(input=int | bool) input_data = buffer_model.input.metadata.data - assert input_data._type == int | bool + assert input_data.type == int | bool def test_set_types_4(): @@ -87,7 +87,7 @@ def test_set_types_4(): model += buffer_model(input="input", output=IOKey(name="output")) model.set_types({model.input: int | bool}) # type: ignore input_data = buffer_model.input.metadata.data - assert input_data._type == int | bool + assert input_data.type == int | bool def test_set_types_5(): @@ -99,8 +99,8 @@ def test_set_types_5(): model.set_types({model.input1: int | bool, "input2": float}) # type: ignore input_data_1 = buffer_model_1.input.metadata.data input_data_2 = buffer_model_2.input.metadata.data - assert input_data_1._type == int | bool - assert input_data_2._type is float + assert input_data_1.type == int | bool + assert input_data_2.type is float def test_set_types_5_key_error(): @@ -132,8 +132,8 @@ def test_set_types_7(): model.set_types({"input": tuple[int, float, int]}) input_data = model.input.metadata.data # type: ignore output_data = model.output.metadata.data # type: ignore - assert input_data._type == tuple[int, float, int] - assert output_data._type is float + assert input_data.type == tuple[int, float, int] + assert output_data.type is float def test_set_types_8(): @@ -143,8 +143,8 @@ def test_set_types_8(): item_model.set_types({"input": tuple[int, float, int]}) input_data = model.input.metadata.data # type: ignore output_data = model.output.metadata.data # type: ignore - assert input_data._type == tuple[int, float, int] - assert output_data._type is float + assert input_data.type == tuple[int, float, int] + assert output_data.type is float def test_set_types_9(): @@ -161,8 +161,8 @@ def test_set_types_9(): model2.set_types({"input": tuple[int, float, int]}) input_data = model1.input.metadata.data # type: ignore output_data = model3.output.metadata.data # type: ignore - assert input_data._type == tuple[int, float, int] - assert output_data._type is float + assert input_data.type == tuple[int, float, int] + assert output_data.type is float def test_types_iokey_1(): @@ -171,8 +171,8 @@ def test_types_iokey_1(): model += buffer_model(input="input", output=IOKey(name="output", type=int)) output_data = model.output.metadata.data # type: ignore input_data = model.input.metadata.data # type: ignore - assert output_data._type is int - assert input_data._type is int + assert output_data.type is int + assert input_data.type is int def test_types_iokey_2(): @@ -185,9 +185,9 @@ def test_types_iokey_2(): output_data = model.output2.metadata.data # type: ignore edge_data = model.output.metadata.data # type: ignore input_data = model.input.metadata.data # type: ignore - assert output_data._type is int - assert input_data._type is int - assert edge_data._type is int + assert output_data.type is int + assert input_data.type is int + assert edge_data.type is int def test_types_iokey_3(): @@ -215,11 +215,11 @@ def test_types_iokey_3(): buffer3_input = buffer_model3.input.metadata.data buffer3_output = buffer_model3.output.metadata.data - assert buffer1_input._type is float - assert buffer1_output._type is float + assert buffer1_input.type is float + assert buffer1_output.type is float - assert buffer2_input._type is float - assert buffer2_output._type is float + assert buffer2_input.type is float + assert buffer2_output.type is float - assert buffer3_input._type is float - assert buffer3_output._type is float + assert buffer3_input.type is float + assert buffer3_output.type is float diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py index b18dc97e..2b8a0526 100644 --- a/tests/scripts/test_shapes.py +++ b/tests/scripts/test_shapes.py @@ -274,7 +274,7 @@ def assert_match_shapes( uni_cache: dict[UniadicRecord, str] = {} var_cache: dict[Variadic, str] = {} - repr1._match(repr2) + repr1.match(repr2) ref_shapes = {"repr1": repr1_ref_shapes, "repr2": repr2_ref_shapes} @@ -1370,7 +1370,7 @@ def test_composite_1_extend_inputs_1(): composite += (m2 := Multiply())(left=m1.right, right=m1.output) composite += Add()(left=m2.output, right=m2.output, output=IOKey(name="output")) composite.set_canonical_input(m1.left) - key_mappings = composite._generate_keys() + key_mappings = composite.generate_keys() m1_out_metadata = composite.conns.get_con_by_metadata(m1.output.metadata) assert m1_out_metadata is not None @@ -2271,7 +2271,7 @@ def test_composite_3_extend_shapes_1(): output=IOKey(name="output"), ) - key_mappings = composite_3._generate_keys() + key_mappings = composite_3.generate_keys() composite_3_left_metadata = composite_3.conns.get_con_by_metadata(m1.left.metadata) # type: ignore assert composite_3_left_metadata is not None @@ -3140,7 +3140,7 @@ def test_shape_3(): model += two_buff_model(input1="input1", output2=IOKey(name="output2")) buff1 = Buffer() model += buff1(input=two_buff_model.output1, output=two_buff_model.input2) # type: ignore - model._generate_keys() + model.generate_keys() buff1.set_shapes({"input": [3, 4, 5, 6]}) logical_ref = { "input1": [3, 4, 5, 6], @@ -5986,7 +5986,7 @@ def test_cartesian_call(): output=IOKey(name="output"), ) - key_mappings = model3._generate_keys() + key_mappings = model3.generate_keys() model_1_out1 = key_mappings[ model3.conns.get_con_by_metadata(model1.output1.metadata).key # type: ignore ] @@ -7915,7 +7915,7 @@ def test_uniadic_repr_count_4(): model += Buffer()(input="input6", output=IOKey(name="output6")) main_model = Model() - main_model += (model1 := deepcopy(model))(**{key: key for key in model._input_keys}) + main_model += (model1 := deepcopy(model))(**{key: key for key in model.input_keys}) main_model += (model2 := deepcopy(model))( input1=model1.output1, # type: ignore @@ -8707,7 +8707,7 @@ def test_possible_variadic_values_14(): PossibleValues((Uniadic(6), Uniadic())), PossibleValues((Uniadic(3), Uniadic(), Uniadic())), ) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == repr2.get_shapes() == [2, 6, 4] @@ -8738,7 +8738,7 @@ def test_possible_variadic_values_14_1(): ) uni = Uniadic(2) repr2 = ShapeRepr(root=var2, prefix=[uni]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == repr2.get_shapes() == [2, 6, 4] @@ -8769,7 +8769,7 @@ def test_possible_variadic_values_14_2(): ) uni = Uniadic({2, 3}) repr2 = ShapeRepr(root=var2, prefix=[uni]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == repr2.get_shapes() == [2, 6, 4] @@ -8801,7 +8801,7 @@ def test_possible_variadic_values_14_2_1(): ) uni = Uniadic({2, 3}) repr2 = ShapeRepr(root=var2, prefix=[uni]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == repr2.get_shapes() == [2, 6, 4] @@ -8832,7 +8832,7 @@ def test_possible_variadic_values_14_2_2(): ) uni = Uniadic({2, 3, 10}) repr2 = ShapeRepr(root=var2, prefix=[uni]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == repr2.get_shapes() == ["u1", 6, 4] assert uni.possible_values == {2, 10} @@ -8865,7 +8865,7 @@ def test_possible_variadic_values_14_3(): repr2 = ShapeRepr(root=var2, prefix=[uni]) with pytest.raises(ValueError) as err_info: - repr1._match(repr2) + repr1.match(repr2) assert str(err_info.value) == "Incompatible possible values for Variadic!" @@ -8895,7 +8895,7 @@ def test_possible_variadic_values_14_4(): repr2 = ShapeRepr(root=var2, prefix=[uni]) with pytest.raises(ValueError) as err_info: - repr1._match(repr2) + repr1.match(repr2) assert str(err_info.value) == "Incompatible possible values for Variadic!" @@ -8937,7 +8937,7 @@ def test_possible_variadic_values_15(): uni2 = Uniadic() uni3 = Uniadic() repr2 = ShapeRepr(root=var2, prefix=[uni1], suffix=[uni2, uni3]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == repr2.get_shapes() == [10, 11, 12, 13, 14, 15] @@ -8963,7 +8963,7 @@ def test_possible_variadic_values_16(): ) repr1 = ShapeRepr(root=var) - repr1._match(repr2) + repr1.match(repr2) assert repr1[0].value == 2 assert repr1.get_shapes() == [2, 1] @@ -8992,7 +8992,7 @@ def test_possible_variadic_values_17(): repr1 = ShapeRepr(prefix=[], root=var1, suffix=[]) repr2 = ShapeRepr(prefix=[], root=var2, suffix=[]) - repr1._match(repr2) + repr1.match(repr2) assert repr2.get_shapes() == repr1.get_shapes() == [1, 2, 3] @@ -9014,7 +9014,7 @@ def test_possible_variadic_values_18(): repr2 = ShapeRepr(prefix=[], root=var, suffix=[]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == [4, 2, 3, 2] assert repr2.get_shapes() == [4, 2, 3, 2] @@ -9036,7 +9036,7 @@ def test_possible_variadic_values_19(): ) repr2 = ShapeRepr(prefix=[], root=var, suffix=[]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == [4, 2, 3, 2] assert repr2.get_shapes() == [4, 2, 3, 2] @@ -9058,7 +9058,7 @@ def test_possible_variadic_values_20(): ) repr2 = ShapeRepr(prefix=[], root=var, suffix=[]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == [5, 2, 3, 2] assert repr2.get_shapes() == [5, 2, 3, 2] @@ -9083,7 +9083,7 @@ def test_possible_variadic_values_21(): repr2 = ShapeRepr(prefix=[], root=var2, suffix=[]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == [5, 6, 9] assert repr2.get_shapes() == [5, 6, 9] @@ -9113,7 +9113,7 @@ def test_possible_variadic_values_22(): repr2 = ShapeRepr(prefix=[], root=var2, suffix=[]) - repr1._match(repr2) + repr1.match(repr2) assert repr2.get_shapes() == repr1.get_shapes() == [5, 6, 9] assert uni1.possible_values == {9} @@ -9165,7 +9165,7 @@ def test_possible_variadic_values_23(): repr1 = ShapeRepr(prefix=[a], root=V1, suffix=[]) repr2 = ShapeRepr(prefix=[], root=V2, suffix=[b]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == [3, 9, 6] assert repr2.get_shapes() == [3, 9, 6] @@ -9220,7 +9220,7 @@ def test_possible_variadic_values_23_1(): repr1 = ShapeRepr(prefix=[a], root=V1, suffix=[]) repr2 = ShapeRepr(prefix=[], root=V2, suffix=[b]) - updates = repr1._match(repr2) + updates = repr1.match(repr2) updates |= b.update_possible_values({5, 6}) @@ -9250,7 +9250,7 @@ def test_possible_variadic_values_24(): repr1 = ShapeRepr(prefix=[a], root=V1, suffix=[]) repr2 = ShapeRepr(prefix=[], root=V2, suffix=[b]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == [2, 5] assert repr2.get_shapes() == [2, 5] @@ -9271,7 +9271,7 @@ def test_possible_variadic_values_26(): repr1 = ShapeRepr(prefix=[a], root=Variadic(), suffix=[]) repr2 = ShapeRepr(prefix=[], root=Variadic(), suffix=[b]) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == ["u1", "(V1, ...)", "u2"] assert repr2.get_shapes() == ["u1", "(V1, ...)", "u2"] @@ -9297,7 +9297,7 @@ def test_possible_variadic_values_27(): repr1 = ShapeRepr(prefix=[a, a], root=None, suffix=[]) repr2 = ShapeRepr(prefix=[], root=v1, suffix=[]) - repr2._match(repr1) + repr2.match(repr1) assert repr1.get_shapes() == [2, 2] assert repr2.get_shapes() == [2, 2] @@ -9323,7 +9323,7 @@ def test_possible_variadic_values_28(): repr1 = ShapeRepr(prefix=[], root=v2, suffix=[]) repr2 = ShapeRepr(prefix=[], root=v1, suffix=[]) - repr2._match(repr1) + repr2.match(repr1) assert repr1.get_shapes() == [2, 3, 4] assert repr2.get_shapes() == [2, 3, 4] @@ -9350,7 +9350,7 @@ def test_possible_variadic_values_29(): ) repr2 = ShapeRepr(root=var) - repr1._match(repr2) + repr1.match(repr2) dnf1 = DNF([AND({uni1: uni7})]) dnf2 = DNF([AND({uni2: uni8})]) @@ -9441,7 +9441,7 @@ def test_impossible_variadic_values_1(): V2.update_possible_values(PossibleValues((Uniadic(3), Uniadic({9, 10})))) - repr1._match(repr2) + repr1.match(repr2) assert repr1.get_shapes() == [3, 9, 6] assert repr2.get_shapes() == [3, 9, 6] diff --git a/tests/scripts/test_train_context.py b/tests/scripts/test_train_context.py index 2e6cd97b..8860f575 100644 --- a/tests/scripts/test_train_context.py +++ b/tests/scripts/test_train_context.py @@ -120,7 +120,7 @@ def test_add_loss_case_2_exception_2(): ) # Finalize train model. - ctx1._finalize() + ctx1.finalize() with pytest.raises(Exception) as err_info: ctx1.add_loss( Subtract(), @@ -372,7 +372,7 @@ def test_add_regularization_case_1(): linear_1 = Linear() model += linear_1( output="output", - **{key: key for key in linear_1._input_keys if not key.startswith("$")}, + **{key: key for key in linear_1.input_keys if not key.startswith("$")}, ) ctx = TrainModel(model) ctx.add_regularization(model=L2(), coef=1e-1, input="weight") @@ -389,7 +389,7 @@ def test_add_regularization_case_2(): model.extend( linear_1, output="output", - **{key: key for key in linear_1._input_keys if not key.startswith("$")}, + **{key: key for key in linear_1.input_keys if not key.startswith("$")}, ) ctx = TrainModel(model) with pytest.raises(KeyError) as err_info: @@ -406,7 +406,7 @@ def test_add_regularization_case_3(): linear_1 = Linear() model += linear_1( output="output", - **{key: key for key in linear_1._input_keys if not key.startswith("$")}, + **{key: key for key in linear_1.input_keys if not key.startswith("$")}, ) ctx = TrainModel(model) @@ -421,7 +421,7 @@ def test_add_regularization_case_6(): linear_1 = Linear() model += linear_1( output="output", - **{key: key for key in linear_1._input_keys if not key.startswith("$")}, + **{key: key for key in linear_1.input_keys if not key.startswith("$")}, ) ctx = TrainModel(model) @@ -436,7 +436,7 @@ def test_add_regularization_case_7(): linear_1 = Linear() model += linear_1( output="output", - **{key: key for key in linear_1._input_keys if not key.startswith("$")}, + **{key: key for key in linear_1.input_keys if not key.startswith("$")}, ) ctx = TrainModel(model) @@ -459,7 +459,7 @@ def test_add_regularization_case_7_exception(): linear_1 = Linear() model += linear_1( output=IOKey(name="output"), - **{key: key for key in linear_1._input_keys if not key.startswith("$")}, + **{key: key for key in linear_1.input_keys if not key.startswith("$")}, ) ctx = TrainModel(model) @@ -467,7 +467,7 @@ def test_add_regularization_case_7_exception(): ctx.add_loss(SquaredError(), input="output", target="target") # Finalize train model. - ctx._finalize() + ctx.finalize() with pytest.raises(Exception) as err_info: ctx.add_regularization( model=QuadraticFormRegularizer(), coef=1e-1, input="weight" diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index 72fe06f4..76e4a5d6 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -568,9 +568,9 @@ def test_type_propagation_1(): right=IOKey(value=2, name="right"), output=IOKey(name="output"), ) - assert model.left.metadata.data._type is int # type: ignore - assert model.right.metadata.data._type is int # type: ignore - assert model.output.metadata.data._type is int # type: ignore + assert model.left.metadata.data.type is int # type: ignore + assert model.right.metadata.data.type is int # type: ignore + assert model.output.metadata.data.type is int # type: ignore def test_type_propagation_2(): @@ -579,9 +579,9 @@ def test_type_propagation_2(): model += Add()( left=IOKey(value=1, name="left"), right="right", output=IOKey(name="output") ) - assert model.left.metadata.data._type is int # type: ignore - assert model.right.metadata.data._type == float | int | bool # type: ignore - assert model.output.metadata.data._type == float | int # type: ignore + assert model.left.metadata.data.type is int # type: ignore + assert model.right.metadata.data.type == float | int | bool # type: ignore + assert model.output.metadata.data.type == float | int # type: ignore def test_type_propagation_3(): @@ -590,9 +590,9 @@ def test_type_propagation_3(): model += Add()( left=IOKey(value=1.0, name="left"), right="right", output=IOKey(name="output") ) - assert model.left.metadata.data._type is float # type: ignore - assert model.right.metadata.data._type == float | int | bool # type: ignore - assert model.output.metadata.data._type is float # type: ignore + assert model.left.metadata.data.type is float # type: ignore + assert model.right.metadata.data.type == float | int | bool # type: ignore + assert model.output.metadata.data.type is float # type: ignore def test_type_propagation_4(): @@ -604,9 +604,9 @@ def test_type_propagation_4(): right="right", output=IOKey(name="output"), ) - assert add.left.metadata.data._type is bool - assert model.right.metadata.data._type == float | int | bool # type: ignore - assert model.output.metadata.data._type == float | int | bool # type: ignore + assert add.left.metadata.data.type is bool + assert model.right.metadata.data.type == float | int | bool # type: ignore + assert model.output.metadata.data.type == float | int | bool # type: ignore def test_type_propagation_5(): @@ -619,9 +619,9 @@ def test_type_propagation_5(): output=IOKey(name="output"), ) - assert add.left.metadata.data._type is bool - assert add.right.metadata.data._type is int - assert model.output.metadata.data._type is int # type: ignore + assert add.left.metadata.data.type is bool + assert add.right.metadata.data.type is int + assert model.output.metadata.data.type is int # type: ignore def test_type_propagation_6(): @@ -634,9 +634,9 @@ def test_type_propagation_6(): output=IOKey(name="output"), ) - assert add.left.metadata.data._type is bool - assert add.right.metadata.data._type is float - assert model.output.metadata.data._type is float # type: ignore + assert add.left.metadata.data.type is bool + assert add.right.metadata.data.type is float + assert model.output.metadata.data.type is float # type: ignore def test_type_propagation_7(): @@ -649,9 +649,9 @@ def test_type_propagation_7(): output=IOKey(name="output"), ) - assert add.left.metadata.data._type is int - assert add.right.metadata.data._type is float - assert model.output.metadata.data._type is float # type: ignore + assert add.left.metadata.data.type is int + assert add.right.metadata.data.type is float + assert model.output.metadata.data.type is float # type: ignore class ArtificialPrimitive(PrimitiveModel): @@ -679,15 +679,15 @@ def artificial_constraint(cls, output: Tensor, input: Tensor): status = False updates = Updates() # Reverse inference - if not isinstance(output._type, UnionType): - # update_type(input, output._type, updates) - input.set_type(output._type) + if not isinstance(output.type, UnionType): + # update_type(input, output.type, updates) + input.set_type(output.type) # updates.add(input, UpdateType._TYPE) status = True # Forward inference elif not isinstance(input, UnionType): - # update_type(output, input._type, updates) - output.set_type(input._type) + # update_type(output, input.type, updates) + output.set_type(input.type) # updates.add(output, UpdateType._TYPE) status = True return status, updates @@ -701,10 +701,10 @@ def test_type_propagation_8(): primitive = ArtificialPrimitive(type=MyTensor[int] | MyTensor[bool]) model += primitive(input=add.output, output=IOKey(name="output")) - assert add.left.metadata.data._type is int - assert add.right.metadata.data._type == int | bool - assert add.output.metadata.data._type is int - assert model.output.metadata.data._type is int # type: ignore + assert add.left.metadata.data.type is int + assert add.right.metadata.data.type == int | bool + assert add.output.metadata.data.type is int + assert model.output.metadata.data.type is int # type: ignore def test_type_propagation_9(): @@ -715,9 +715,9 @@ def test_type_propagation_9(): model += add(left=IOKey(value=[1], name="left"), right="right") model += tensor_to_list(input=add.output, output=IOKey(name="output")) - assert add.left.metadata.data._type is int - assert add.right.metadata.data._type is float - assert model.output.metadata.data._type is float # type: ignore + assert add.left.metadata.data.type is int + assert add.right.metadata.data.type is float + assert model.output.metadata.data.type is float # type: ignore def test_type_propagation_10(): @@ -728,9 +728,9 @@ def test_type_propagation_10(): model += add(left="right", right="right") model += tensor_to_list(input=add.output, output=IOKey(name="output")) - assert add.left.metadata.data._type == int | bool - assert add.right.metadata.data._type == int | bool - assert model.output.metadata.data._type == int | bool # type: ignore + assert add.left.metadata.data.type == int | bool + assert add.right.metadata.data.type == int | bool + assert model.output.metadata.data.type == int | bool # type: ignore def test_type_propagation_floor_divide_1(): @@ -743,9 +743,9 @@ def test_type_propagation_floor_divide_1(): numerator=add.left, denominator=add.output, output=IOKey(name="output") ) - assert add.left.metadata.data._type is int - assert add.right.metadata.data._type == float | int | bool - assert model.output.metadata.data._type == float | int # type: ignore + assert add.left.metadata.data.type is int + assert add.right.metadata.data.type == float | int | bool + assert model.output.metadata.data.type == float | int # type: ignore def test_type_propagation_floor_divide_2(): @@ -758,11 +758,11 @@ def test_type_propagation_floor_divide_2(): model += floor_div(numerator=add.left, denominator=add.output) model += ap(input=floor_div.output, output=IOKey(name="output")) - assert add.left.metadata.data._type is int - assert add.right.metadata.data._type == int | bool - assert floor_div.denominator.metadata.data._type is int - assert floor_div.output.metadata.data._type is int - assert model.output.metadata.data._type is int # type: ignore + assert add.left.metadata.data.type is int + assert add.right.metadata.data.type == int | bool + assert floor_div.denominator.metadata.data.type is int + assert floor_div.output.metadata.data.type is int + assert model.output.metadata.data.type is int # type: ignore def test_type_propagation_floor_divide_3(): @@ -775,11 +775,11 @@ def test_type_propagation_floor_divide_3(): model += floor_div(numerator=add.left, denominator=add.output) model += ap(input=floor_div.output, output=IOKey(name="output")) - assert add.left.metadata.data._type is int - assert add.right.metadata.data._type == float | int | bool - assert floor_div.denominator.metadata.data._type == float | int - assert floor_div.output.metadata.data._type == float | int - assert model.output.metadata.data._type == float | int # type: ignore + assert add.left.metadata.data.type is int + assert add.right.metadata.data.type == float | int | bool + assert floor_div.denominator.metadata.data.type == float | int + assert floor_div.output.metadata.data.type == float | int + assert model.output.metadata.data.type == float | int # type: ignore def test_type_propagation_floor_divide_4(): @@ -883,7 +883,7 @@ def test_type_initialization_1(): model = Model() model += LeakyRelu()(slope=IOKey("slope", 0.5)) - assert model.slope.metadata.data._type is float # type: ignore + assert model.slope.metadata.data.type is float # type: ignore def test_connect_1(): diff --git a/tests/scripts/test_type_consistencies.py b/tests/scripts/test_type_consistencies.py index 5a6eb966..ac7ac596 100644 --- a/tests/scripts/test_type_consistencies.py +++ b/tests/scripts/test_type_consistencies.py @@ -180,7 +180,7 @@ def test_type_1(): input=[[[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]]], ) - assert shape1.output.data.metadata.data._type == tuple[int, int] + assert shape1.output.data.metadata.data.type == tuple[int, int] def test_type_2(): @@ -194,7 +194,7 @@ def test_type_2(): model += shape3(input=[[1, 2, 4], [3, 5, 7]]) model += union1(input1=shape1.output, input2=shape2.output, input3=shape3.output) - assert shape1.output.data.metadata.data._type == tuple[int, int] + assert shape1.output.data.metadata.data.type == tuple[int, int] def test_type_3(): @@ -205,11 +205,11 @@ def test_type_3(): shape3 = Shape() model += union1() input1 = union1.input1 # type: ignore - assert input1.data.metadata.data._type == int | float | tuple[int | float, ...] + assert input1.data.metadata.data.type == int | float | tuple[int | float, ...] model += shape1(input=[[1, 2, 4], [3, 5, 7]], output=input1) model += shape2(input=[[1, 2, 4], [3, 5, 7]], output=union1.input2) # type: ignore model += shape3(input=[[1, 2, 4], [3, 5, 7]], output=union1.input3) # type: ignore - assert input1.data.metadata.data._type == tuple[int, int] + assert input1.data.metadata.data.type == tuple[int, int] def test_type_5(): @@ -223,7 +223,7 @@ def test_type_5(): input=[[[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]]], ) model += conv1(input="", stride=shape1.output) - assert shape1.output.data.metadata.data._type == tuple[int, int] + assert shape1.output.data.metadata.data.type == tuple[int, int] def test_type_6(): @@ -246,9 +246,9 @@ def test_type_7(): test_model_3 = Model2() model += test_model_1(input1="input1", input2="input2", input3="input3") input1 = model.input1 # type: ignore - assert input1.data.metadata.data._type == int | float + assert input1.data.metadata.data.type == int | float model += test_model_2(input1="", input2="input1") - assert input1.data.metadata.data._type is int + assert input1.data.metadata.data.type is int with pytest.raises(TypeError) as err_info: model += test_model_3(input1="", input3="input1") assert ( @@ -264,9 +264,9 @@ def test_type_8(): model3 = Model3() model += model3(input1="input1", input2="input1", input3="input1", output="output") input1 = model.input1 # type: ignore - assert input1.data.metadata.data._type == tuple[int, int, int, int] + assert input1.data.metadata.data.type == tuple[int, int, int, int] model += model1(input1="input1") - assert input1.data.metadata.data._type == tuple[int, int, int, int] + assert input1.data.metadata.data.type == tuple[int, int, int, int] with pytest.raises(TypeError) as err_info: model += model2(input1="input1") assert str(err_info.value) == ( @@ -278,59 +278,59 @@ def test_type_8(): def test_type_9(): model = Model() lin_model = Linear() - assert lin_model.input.data.metadata.data._type == int | float | bool + assert lin_model.input.data.metadata.data.type == int | float | bool model += lin_model( input=IOKey(value=[[1.0, 2.0], [3.0, 4.0]], name="input"), weight="w", bias="b", output=IOKey(name="output"), ) - assert lin_model.input.data.metadata.data._type is float + assert lin_model.input.data.metadata.data.type is float def test_type_10(): model = Model() lin_model = Linear() - assert lin_model.input.data.metadata.data._type == int | float | bool + assert lin_model.input.data.metadata.data.type == int | float | bool model += lin_model( input=IOKey(value=[[False, 1], [True, False]], name="input"), weight="w", bias="b", output=IOKey(name="output"), ) - assert lin_model.input.data.metadata.data._type is int + assert lin_model.input.data.metadata.data.type is int def test_type_11(): model = Model() lin_model = Linear() - assert lin_model.input.data.metadata.data._type == int | float | bool + assert lin_model.input.data.metadata.data.type == int | float | bool model += lin_model( input=IOKey(value=[[False, 1], [2.2, False]], name="input"), weight="w", bias="b", output=IOKey(name="output"), ) - assert lin_model.input.data.metadata.data._type is float + assert lin_model.input.data.metadata.data.type is float def test_type_12(): model = Model() lin_model = Linear() - assert lin_model.input.data.metadata.data._type == int | float | bool + assert lin_model.input.data.metadata.data.type == int | float | bool model += lin_model( input=IOKey(value=[[False, 1], [2.2, False]], name="input"), weight="w", bias="b", output=IOKey(name="output"), ) - assert lin_model.input.data.metadata.data._type is float + assert lin_model.input.data.metadata.data.type is float def test_type_13(): model = Model() lin_model = Linear() - assert lin_model.input.data.metadata.data._type == int | float | bool + assert lin_model.input.data.metadata.data.type == int | float | bool model += lin_model( input=IOKey(value=[[False, True], [False, False]], name="input"), weight="w", @@ -338,27 +338,27 @@ def test_type_13(): output=IOKey(name="output"), ) # model.make_static("input", [[False, True], [False, False]]) - assert lin_model.input.data.metadata.data._type is bool + assert lin_model.input.data.metadata.data.type is bool def test_type_14(): model = Model() lin_model = Linear() - assert lin_model.input.data.metadata.data._type == int | float | bool + assert lin_model.input.data.metadata.data.type == int | float | bool model += lin_model( input=IOKey(value=[[False, 1.0], [2, 3]], name="input"), weight="w", bias="b", output=IOKey(name="output"), ) - assert lin_model.input.data.metadata.data._type is float + assert lin_model.input.data.metadata.data.type is float def test_type_15(): model = Model() sig_model = Sigmoid() sig_model_2 = Sigmoid() - sig_model_2.input.data.metadata.data._type = float + sig_model_2.input.data.metadata.data.type = float model += sig_model(input="input", output=IOKey(name="output")) model += sig_model_2( @@ -382,7 +382,7 @@ def test_type_16(): model = Model() sig_model_1 = Sigmoid() sig_model_2 = Sigmoid() - sig_model_1.input.data.metadata.data._type = float + sig_model_1.input.data.metadata.data.type = float model += sig_model_1(input="input", output=IOKey(name="output")) with pytest.raises(TypeError) as err_info: @@ -400,7 +400,7 @@ def test_type_17(): model = Model() sig_model_1 = Sigmoid() sig_model_2 = Sigmoid() - sig_model_1.input.data.metadata.data._type = float + sig_model_1.input.data.metadata.data.type = float model.extend(sig_model_1, input="input", output="output") with pytest.raises(TypeError) as err_info: model.extend( diff --git a/tests/scripts/test_utils.py b/tests/scripts/test_utils.py index 77b2182e..c1b927dd 100644 --- a/tests/scripts/test_utils.py +++ b/tests/scripts/test_utils.py @@ -272,7 +272,7 @@ def get_all_data(model: BaseModel) -> set[Scalar | Tensor]: def get_all_metadata(model: BaseModel) -> set[IOHyperEdge | None]: # recursively gets the all metadata in the model (IOHyperEdge) if isinstance(model, PrimitiveModel): - return {model.conns._get_metadata(key) for key in model.conns.all} + return {model.conns.get_metadata(key) for key in model.conns.all} assert isinstance(model, Model) data = set() for submodel in model.dag: @@ -526,13 +526,13 @@ def assert_connections( compiled_model: PhysicalModel, expected_connections: dict[str, list[str | set[str]]] ): result_connections = {} - for key in compiled_model._flat_graph.all_target_keys: - if key not in compiled_model._flat_graph.connections: + for key in compiled_model.flat_graph.all_target_keys: + if key not in compiled_model.flat_graph.connections: continue - node = compiled_model._flat_graph.connections[key].node + node = compiled_model.flat_graph.connections[key].node assert node is not None - formula_key = node.model._formula_key + formula_key = node.model.formula_key keys = {conn.key for conn in node.connections.values() if conn.key != key} result_connections[key] = [formula_key, keys] diff --git a/tests/utils.py b/tests/utils.py index 516012e5..24b65b0d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -85,8 +85,8 @@ def check_physical_models( ): if check_internals: # Check flat_graphs. - assert pm_1._flat_graph.all_source_keys == pm_2._flat_graph.all_source_keys - assert pm_1._flat_graph.all_target_keys == pm_2._flat_graph.all_target_keys + assert pm_1.flat_graph.all_source_keys == pm_2.flat_graph.all_source_keys + assert pm_1.flat_graph.all_target_keys == pm_2.flat_graph.all_target_keys # Check data stores. for key, value in pm_1.data.items(): From c1d22b57208faf6e22504a84c46fe3332cfb3bcf Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Fri, 27 Dec 2024 10:23:56 +0300 Subject: [PATCH 23/26] some missing changes are added --- mithril/framework/logical/base.py | 24 ++++++++++++++++-------- mithril/framework/logical/primitive.py | 12 ++++++++++-- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 0cf2702c..2804463a 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -62,20 +62,28 @@ @dataclass class ExtendInfo: - model: BaseModel - connections: dict[str, ConnectionType] + _model: BaseModel + _connections: dict[str, ConnectionType] def __post_init__(self): - external_keys = set(self.model.external_keys) - if self.model.canonical_input is not NOT_AVAILABLE: - external_keys.add(self.model.canonical_input.key) - if self.model.canonical_output is not NOT_AVAILABLE: - external_keys.add(self.model.canonical_output.key) + external_keys = set(self._model.external_keys) + if self._model.canonical_input is not NOT_AVAILABLE: + external_keys.add(self._model.canonical_input.key) + if self._model.canonical_output is not NOT_AVAILABLE: + external_keys.add(self._model.canonical_output.key) - for key in self.connections: + for key in self._connections: if key not in external_keys: raise KeyError(f"Key '{key}' is not a valid key for the model!") + @property + def model(self): + return self._model + + @property + def connections(self): + return self._connections + class BaseModel(abc.ABC): # Disposable models only used once for entire training session. diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 919a3395..3ca9b5e3 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -56,8 +56,8 @@ def __init__( name: str | None = None, **kwargs: BaseKey | Tensor | Scalar, ) -> None: - self.formula_key = formula_key - self.grad_formula = formula_key + "_grad" + self._formula_key = formula_key + self._grad_formula = formula_key + "_grad" super().__init__(name=name) # Get shape_templates of TensorTypes and create corresponding shapes. @@ -173,6 +173,14 @@ def __iadd__(self, other: BaseModel): f"Primitive '{self.__class__.__name__}' model can not be extended!" ) + @property + def formula_key(self) -> str: + return self._formula_key + + @property + def grad_formula(self) -> str: + return self._grad_formula + def extract_connection_info( self, name_mappings: dict[BaseModel, str], From 4b99c38e8b570db9a3f0197a0771e8885dc773b5 Mon Sep 17 00:00:00 2001 From: kberat-synnada <97015093+kberat-synnada@users.noreply.github.com> Date: Fri, 27 Dec 2024 10:42:13 +0300 Subject: [PATCH 24/26] Update PhysicalModel _flat_graph attribute --- mithril/framework/physical/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 64bb7658..5be36fa4 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -519,7 +519,7 @@ def _pre_compile( # Set given shapes. self.data_store.set_shapes(shapes) - for node in self._flat_graph.nodes.values(): + for node in self.flat_graph.nodes.values(): conn_data = node.model.conns.get_connection("output") assert conn_data is not None if isinstance(conn_data.metadata.data, Scalar) or ( From 73d8e0b3e689be2dd21a44a97c8ad8653334c26d Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Tue, 31 Dec 2024 14:34:35 +0300 Subject: [PATCH 25/26] type errors are partially cleared --- benchmarks/speed_benchmarks/jax_fns.py | 2 +- .../linear_regression_jax_training.py | 2 +- .../variable_length_many_to_one_lstm.py | 10 +- mithril/backends/backend.py | 13 +- mithril/backends/parallel.py | 2 +- .../with_autograd/common_primitives.py | 98 +++++++-------- .../with_autograd/jax_backend/backend.py | 20 +-- .../backends/with_autograd/jax_backend/ops.py | 8 +- .../with_autograd/jax_backend/parallel.py | 10 +- .../with_autograd/jax_backend/utils.py | 4 +- .../with_autograd/mlx_backend/backend.py | 2 +- .../with_manualgrad/c_backend/backend.py | 2 +- .../with_manualgrad/c_backend/src/array.pyi | 1 + .../with_manualgrad/numpy_backend/ops.py | 4 +- .../with_manualgrad/numpy_backend/ops_grad.py | 4 +- .../with_manualgrad/numpy_backend/utils.py | 2 +- mithril/framework/codegen/c_gen.py | 2 +- mithril/framework/codegen/python_gen.py | 38 +++--- mithril/framework/codegen/torch_gen.py | 2 +- mithril/framework/common.py | 28 +++-- mithril/framework/constraints.py | 4 +- mithril/framework/logical/base.py | 2 +- mithril/framework/logical/primitive.py | 4 +- mithril/framework/physical/data_store.py | 49 ++++---- mithril/framework/physical/flat_graph.py | 72 +++++------ mithril/framework/physical/model.py | 115 ++++++++++-------- mithril/framework/utils.py | 15 ++- mithril/models/models.py | 10 +- mithril/models/train_model.py | 12 +- mithril/utils/dict_conversions.py | 2 +- mypy.ini | 7 +- releases/generate_changelog.py | 2 +- tests/scripts/test_data_store.py | 32 ++--- tests/scripts/test_flatmodel.py | 8 +- tests/utils.py | 4 +- 35 files changed, 307 insertions(+), 285 deletions(-) diff --git a/benchmarks/speed_benchmarks/jax_fns.py b/benchmarks/speed_benchmarks/jax_fns.py index 27423f5d..170e1ccd 100644 --- a/benchmarks/speed_benchmarks/jax_fns.py +++ b/benchmarks/speed_benchmarks/jax_fns.py @@ -66,7 +66,7 @@ def setup(self): def __call__(self, inputs): x = inputs for lyr, actv in zip(self.layers, self.jax_activations, strict=False): - x = lyr(x) # type: ignore + x = lyr(x) x = actv(x) # type: ignore return x diff --git a/examples/model_api/linear_regression_jax_training.py b/examples/model_api/linear_regression_jax_training.py index 058f2205..0769b2b1 100644 --- a/examples/model_api/linear_regression_jax_training.py +++ b/examples/model_api/linear_regression_jax_training.py @@ -58,5 +58,5 @@ for i in range(num_epochs): outputs, gradients = pm.evaluate_all(params) updates, opt_state = optimizer.update(gradients, opt_state) - params = optax.apply_updates(params, updates) # type: ignore + params = optax.apply_updates(params, updates) print(f"Epoch: {i} / {num_epochs} -> ", outputs["final_cost"]) diff --git a/examples/model_api/variable_length_many_to_one_lstm.py b/examples/model_api/variable_length_many_to_one_lstm.py index 1671ed61..335e3ff3 100644 --- a/examples/model_api/variable_length_many_to_one_lstm.py +++ b/examples/model_api/variable_length_many_to_one_lstm.py @@ -67,8 +67,8 @@ target_end = int(input_end + target_lengths[idx]) # NOTE: Pylance sees int, int type arguments but throws an error. - single_input = backend.arange(start, input_end).reshape(-1, input_dim) # type: ignore - single_target = backend.arange(input_end, target_end).reshape(-1, output_dim) # type: ignore + single_input = backend.arange(start, input_end).reshape(-1, input_dim) + single_target = backend.arange(input_end, target_end).reshape(-1, output_dim) single_data = (single_input, single_target) train_data.append(single_data) @@ -150,7 +150,7 @@ # Prepare the test input data. test_input = backend.arange( starting_number, - starting_number + inference_max_input, # type: ignore + starting_number + inference_max_input, ).reshape(-1, input_dim) # Prepare the test data. @@ -172,7 +172,7 @@ # Prepare target values. test_target_values = backend.arange( - starting_number + inference_max_input, # type: ignore + starting_number + inference_max_input, starting_number + inference_max_input + inference_max_target_length, ) @@ -204,4 +204,4 @@ ) # Measure test error. -error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum() # type: ignore +error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum() diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py index bbde448c..3b8b02fd 100644 --- a/mithril/backends/backend.py +++ b/mithril/backends/backend.py @@ -57,11 +57,12 @@ def __init__(self, precision: int = 32, device: str = "cpu") -> None: # setattr(self, key, value) @property - def precision(self): + def precision(self) -> int: return self._precision + #!! @property - def device(self): + def device(self) -> Any: return self._device def get_device(self): @@ -72,11 +73,11 @@ def inf(self) -> DataType | float: raise NotImplementedError("inf is not implemented") @property - def pi(self): + def pi(self) -> float: return math.pi @property - def e(self): + def e(self) -> float: return math.e @property @@ -104,7 +105,7 @@ def to_device( def block_until_ready(self, data: DataType) -> DataType | None: raise RuntimeError("Backend does not support block_until_ready method!") - def empty_cache(self): # noqa: B027 + def empty_cache(self) -> None: # noqa: B027 pass # print("Warning: empty_cache is not supported!") @@ -126,7 +127,7 @@ def cast(self, value: Any) -> Any: return value - def __del__(self): + def __del__(self) -> None: self.empty_cache() @overload diff --git a/mithril/backends/parallel.py b/mithril/backends/parallel.py index ea8451a5..243abb68 100644 --- a/mithril/backends/parallel.py +++ b/mithril/backends/parallel.py @@ -40,7 +40,7 @@ def parallelize( ) -> dict[str, Any]: raise NotImplementedError() - def clean_up(self): + def clean_up(self) -> None: self.callables = dict() self.device_mesh = None self.n_devices = -1 diff --git a/mithril/backends/with_autograd/common_primitives.py b/mithril/backends/with_autograd/common_primitives.py index 50e3a064..98cd6000 100644 --- a/mithril/backends/with_autograd/common_primitives.py +++ b/mithril/backends/with_autograd/common_primitives.py @@ -63,80 +63,80 @@ ] -def greater(left: DataType, right: DataType): +def greater(left: DataType, right: DataType) -> DataType: return left > right -def greater_equal(left: DataType, right: DataType): +def greater_equal(left: DataType, right: DataType) -> DataType: return left >= right -def less(left: DataType, right: DataType): +def less(left: DataType, right: DataType) -> DataType: return left < right -def less_equal(left: DataType, right: DataType): +def less_equal(left: DataType, right: DataType) -> DataType: return left <= right -def equal(left: DataType, right: DataType): - return left == right +def equal(left: DataType, right: DataType) -> DataType: + return left == right # type: ignore -def not_equal(left: DataType, right: DataType): - return left != right +def not_equal(left: DataType, right: DataType) -> DataType: + return left != right # type: ignore -def logical_not(input: DataType): +def logical_not(input: DataType) -> DataType: return ~input -def logical_or(left: DataType, right: DataType): - return left | right +def logical_or(left: DataType, right: DataType) -> DataType: + return left | right # type: ignore -def logical_and(left: DataType, right: DataType): - return left & right +def logical_and(left: DataType, right: DataType) -> DataType: + return left & right # type: ignore -def matrix_multiplication(left: DataType, right: DataType): - return left @ right +def matrix_multiplication(left: DataType, right: DataType) -> DataType: + return left @ right # type: ignore -def add(left: DataType, right: DataType): - return left + right +def add(left: DataType, right: DataType) -> DataType: + return left + right # type: ignore -def subtract(left: DataType, right: DataType): - return left - right +def subtract(left: DataType, right: DataType) -> DataType: + return left - right # type: ignore -def multiplication(left: DataType, right: DataType): - return left * right +def multiplication(left: DataType, right: DataType) -> DataType: + return left * right # type: ignore -def divide(numerator: DataType, denominator: DataType): - return numerator / denominator +def divide(numerator: DataType, denominator: DataType) -> DataType: + return numerator / denominator # type: ignore -def floor_divide(numerator: DataType, denominator: DataType): - return numerator // denominator +def floor_divide(numerator: DataType, denominator: DataType) -> DataType: + return numerator // denominator # type: ignore -def shift_left(input: DataType, shift: DataType): - return input << shift +def shift_left(input: DataType, shift: DataType) -> DataType: + return input << shift # type: ignore -def shift_right(input: DataType, shift: DataType): - return input >> shift +def shift_right(input: DataType, shift: DataType) -> DataType: + return input >> shift # type: ignore -def power(base: DataType, exponent: DataType): - return base**exponent +def power(base: DataType, exponent: DataType) -> DataType: + return base**exponent # type: ignore -def squared_error(input: DataType, target: DataType): - return (input - target) ** 2 +def squared_error(input: DataType, target: DataType) -> DataType: + return (input - target) ** 2 # type: ignore def minus(input: DataType) -> DataType: @@ -148,18 +148,18 @@ def transpose( ) -> DataType: if not axes: return input.T - return input.transpose(*axes) + return input.transpose(*axes) # type: ignore -def swapaxes(input: DataType, axis1: int, axis2: int): +def swapaxes(input: DataType, axis1: int, axis2: int) -> DataType: return input.swapaxes(axis1, axis2) -def square(input: DataType): - return input * input +def square(input: DataType) -> DataType: + return input * input # type: ignore -def buffer(input: DataType): +def buffer(input: DataType) -> DataType: return input @@ -168,18 +168,20 @@ def permute_tensor(input: DataType, indices: DataType) -> DataType: def reshape(input: DataType, shape: tuple[int, ...]) -> DataType: - return input.reshape(shape) + return input.reshape(shape) # type: ignore def item(input: DataType) -> int | float | bool: return input.item() # type: ignore -def tensor_item(input: DataType, index: int | slice | tuple[int | slice, ...]): - return input[index] +def tensor_item( + input: DataType, index: int | slice | tuple[int | slice, ...] +) -> DataType: + return input[index] # type: ignore -def primitive_slice(start: int | None, stop: int | None, step: int | None): +def primitive_slice(start: int | None, stop: int | None, step: int | None) -> slice: return slice(start, stop, step) @@ -187,8 +189,8 @@ def length(input: DataType) -> int: return len(input) -def cartesian_diff(left: DataType, right: DataType): - return left[:, None, :] - right[None, :, :] +def cartesian_diff(left: DataType, right: DataType) -> DataType: + return left[:, None, :] - right[None, :, :] # type: ignore def primitive_embedding(input: DataType, weight: DataType) -> DataType: @@ -218,11 +220,11 @@ def union(*inputs: int | float | tuple[int | float, ...]) -> tuple[int | float, return result -def to_tuple(*args: tuple[int | float | bool, ...]): +def to_tuple(*args: int | float | bool) -> tuple[int | float | bool, ...]: return tuple(args) -def to_list(*args: tuple[int | float | bool, ...]): +def to_list(*args: int | float | bool) -> list[int | float | bool]: return list(args) @@ -291,7 +293,7 @@ def padding_converter_2d( def stride_converter( input: int | PaddingType | tuple[int, int] | None, kernel_size: int | tuple[int, int], -): +) -> int | tuple[int, int] | PaddingType: if input is None: return kernel_size else: @@ -303,7 +305,7 @@ def tuple_converter( | PaddingType | tuple[int, int] | tuple[tuple[int, int], tuple[int, int]], -): +) -> tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] | PaddingType: if isinstance(input, int): return (input, input) else: diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py index af371d09..06cf489a 100644 --- a/mithril/backends/with_autograd/jax_backend/backend.py +++ b/mithril/backends/with_autograd/jax_backend/backend.py @@ -75,14 +75,14 @@ def is_manualgrad(self) -> bool: return False @property - def inf(self): + def inf(self) -> float: return jax.numpy.inf @property - def nan(self): + def nan(self) -> float: return jax.numpy.nan - def get_backend_array_type(self): + def get_backend_array_type(self) -> type[jax.Array]: return jax.Array @property @@ -98,7 +98,7 @@ def DataType(self): # noqa: N802 return utils.ArrayType @staticmethod - def get_available_devices(): + def get_available_devices() -> list[str]: """Static method to get a list of available devices. Parameters @@ -112,7 +112,7 @@ def get_available_devices(): def register_primitive(fn: Callable[..., Any]) -> None: JaxBackend.registered_primitives[fn.__name__] = fn - def set_seed(self, seed: int): + def set_seed(self, seed: int) -> None: self.seed = seed self.prng_key = jax.random.PRNGKey(seed) @@ -145,7 +145,7 @@ def block_until_ready(self, data: jax.Array) -> jax.Array | None: def register_callable( self, fn: Callable[..., Any], fn_name: str, jit: bool = False - ): + ) -> None: assert ( self._parallel_manager is not None ), "Parallel manager is not initialized!" @@ -153,7 +153,7 @@ def register_callable( fn_name = str(id(self)) + fn_name return self._parallel_manager.register_callable(fn, fn_name, jit) - def _run_callable(self, *primals: jax.Array, fn_name: str): + def _run_callable(self, *primals: jax.Array, fn_name: str) -> Any: assert ( self._parallel_manager is not None ), "Parallel manager is not initialized!" @@ -161,7 +161,7 @@ def _run_callable(self, *primals: jax.Array, fn_name: str): fn_name = str(id(self)) + fn_name return self._parallel_manager.run_callable(*primals, fn_name=fn_name) - def _create_parallel(self, device_mesh: tuple[int, ...]): + def _create_parallel(self, device_mesh: tuple[int, ...]) -> None: self._parallel_manager = JaxParallel(math.prod(device_mesh), self._device) def array( @@ -540,7 +540,9 @@ def multinomial( return samples - def jit(self, *args: Any, **kwargs: Any): + def jit( + self, *args: Any, **kwargs: Any + ) -> Callable[..., jax.Array | tuple[jax.Array, ...]] | dict[str, jax.Array]: return jax.jit(*args, **kwargs) def grad( diff --git a/mithril/backends/with_autograd/jax_backend/ops.py b/mithril/backends/with_autograd/jax_backend/ops.py index 8c338dca..d204be75 100644 --- a/mithril/backends/with_autograd/jax_backend/ops.py +++ b/mithril/backends/with_autograd/jax_backend/ops.py @@ -424,7 +424,7 @@ def conv2d( stride: tuple[int, int] = (1, 1), padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (1, 1), dilation: tuple[int, int] = (1, 1), -): +) -> jax.Array: _padding_normalized: tuple[tuple[int, int], tuple[int, int]] if is_tuple_int(padding): _padding_normalized = ((padding[0], padding[0]), (padding[1], padding[1])) @@ -451,7 +451,7 @@ def conv2d_bias( stride: tuple[int, int] = (1, 1), padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (1, 1), dilation: tuple[int, int] = (1, 1), -): +) -> jax.Array: return ( conv2d( input=input, @@ -559,7 +559,7 @@ def scaled_dot_product_attention( dropout_p: float = 0.0, is_causal: bool = False, scale: float | int | None = None, -): +) -> jax.Array: if dropout_p != 0.0: raise RuntimeError( "Currently Jax scaled_dot_product_attention only support dropout_p 0" @@ -632,7 +632,7 @@ def cross_entropy( f"Cross entropy got unexpected type for target '{target.dtype}'." ) - return ( + return ( # type: ignore -log(jnp.take_along_axis(input, target[:, None], axis=1)[:, 0]) * _weights[target] ) diff --git a/mithril/backends/with_autograd/jax_backend/parallel.py b/mithril/backends/with_autograd/jax_backend/parallel.py index 2ae2c9bd..901c182e 100644 --- a/mithril/backends/with_autograd/jax_backend/parallel.py +++ b/mithril/backends/with_autograd/jax_backend/parallel.py @@ -32,12 +32,12 @@ def __init__(self, n_devices: int, device: str) -> None: ) super().__init__(n_devices) - def run_callable(self, *primals: jax.Array, fn_name: str): + def run_callable(self, *primals: jax.Array, fn_name: str) -> Any: return self.callables[fn_name](*primals) def parallelize( self, tensor: jax.Array, device_mesh: tuple[int, ...] | None = None - ): + ) -> jax.Array: # Jax reuqires math.prod(device_mesh) == n_devices. To replicate a dimension # call 'replicate' method of Positional Sharding Object. Therefore, we need to # transform user provided device mesh to the one that satisfies the condition, @@ -67,11 +67,13 @@ def parallelize( return jax.device_put(tensor, sharding) - def register_callable(self, fn: Callable[..., Any], fn_name: str, jit: bool): + def register_callable( + self, fn: Callable[..., Any], fn_name: str, jit: bool + ) -> None: if jit: fn = jax.jit(fn) self.callables[fn_name] = fn - def clean_up(self): + def clean_up(self) -> None: self.callables = {} diff --git a/mithril/backends/with_autograd/jax_backend/utils.py b/mithril/backends/with_autograd/jax_backend/utils.py index 9c5139b6..367d2e90 100644 --- a/mithril/backends/with_autograd/jax_backend/utils.py +++ b/mithril/backends/with_autograd/jax_backend/utils.py @@ -78,7 +78,7 @@ def robust_power_helper( input1: jax.Array, input2: jax.Array, threshold: jax.Array ) -> jax.Array: def cond_fun(cond: jax.Array, input1: jax.Array, input2: jax.Array) -> jax.Array: - return jax.lax.cond( + return jax.lax.cond( # type: ignore cond, robust_power_under_threshold, robust_power_above_threshold, @@ -284,7 +284,7 @@ def polynomial_features_helper(x: jax.Array, y: jax.Array) -> jax.Array: ) -def get_available_devices(): +def get_available_devices() -> list[str]: backends: set[str] = set(jax._src.xla_bridge.backends()) - set(["interpreter"]) devices = [ f"{backend.replace('METAL','mps')}:{idx}" diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index a80f0c4d..dfa50e8b 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -53,7 +53,7 @@ def is_manualgrad(self) -> bool: return False @property - def inf(self): + def inf(self) -> float: return mx.inf @property diff --git a/mithril/backends/with_manualgrad/c_backend/backend.py b/mithril/backends/with_manualgrad/c_backend/backend.py index 369f1882..3710f7d1 100644 --- a/mithril/backends/with_manualgrad/c_backend/backend.py +++ b/mithril/backends/with_manualgrad/c_backend/backend.py @@ -29,7 +29,7 @@ class CBackend(Backend[PyArray]): type = "c" SRC_PATH = "mithril/backends/with_manualgrad/c_backend/src" - def __init__(self): + def __init__(self) -> None: self._precision = 32 self._device = "cpu" self.primitive_function_dict = {} diff --git a/mithril/backends/with_manualgrad/c_backend/src/array.pyi b/mithril/backends/with_manualgrad/c_backend/src/array.pyi index fd0a17a4..e36bb5c0 100644 --- a/mithril/backends/with_manualgrad/c_backend/src/array.pyi +++ b/mithril/backends/with_manualgrad/c_backend/src/array.pyi @@ -40,6 +40,7 @@ class PyArray: def __le__(self, other: PyArray) -> PyArray: ... def __and__(self, other: PyArray) -> PyArray: ... def __or__(self, other: PyArray) -> PyArray: ... + def __ror__(self, other: PyArray) -> PyArray: ... def __xor__(self, other: PyArray) -> PyArray: ... def __invert__(self) -> PyArray: ... def __matmul__(self, other: PyArray) -> PyArray: ... diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops.py b/mithril/backends/with_manualgrad/numpy_backend/ops.py index cd29c29f..5720a615 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops.py @@ -20,8 +20,8 @@ from typing import Any import numpy as np -import scipy.linalg as slin # type: ignore[import-untyped] -from scipy.special import erf # type: ignore[import-untyped] +import scipy.linalg as slin +from scipy.special import erf from .... import core from ....utils.type_utils import is_tuple_int diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py index d4584aa5..a0e3bf31 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py @@ -17,8 +17,8 @@ from typing import Any import numpy as np -import scipy.linalg as slin # type: ignore[import-untyped] -from scipy.special import erf # type: ignore[import-untyped] +import scipy.linalg as slin +from scipy.special import erf from ....utils.type_utils import is_tuple_int from .ops import hinge_loss, sigmoid, softmax diff --git a/mithril/backends/with_manualgrad/numpy_backend/utils.py b/mithril/backends/with_manualgrad/numpy_backend/utils.py index 42d00762..2d8bccbf 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/utils.py +++ b/mithril/backends/with_manualgrad/numpy_backend/utils.py @@ -85,7 +85,7 @@ def write_into_cache[T: np.ndarray[Any, Any] | tuple[Any, ...] | int | float]( else: result = cache[key] # TODO: Resolve here - return result # type: ignore + return result def get_submatrices1d( diff --git a/mithril/framework/codegen/c_gen.py b/mithril/framework/codegen/c_gen.py index aba85b5e..d54c5d62 100644 --- a/mithril/framework/codegen/c_gen.py +++ b/mithril/framework/codegen/c_gen.py @@ -277,7 +277,7 @@ def generate_evaluate_gradients(self) -> tuple[c_ast.FunctionDef, set[str]]: return evaluate_grad_fn, used_keys - def _get_backend_path(self): + def _get_backend_path(self) -> str: backend_path = backend.__file__ return backend_path[: backend_path.rindex("/")] diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index 97750ffa..5961ce9c 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -21,10 +21,10 @@ from typing import Any, Generic, Literal, Protocol, overload from ...backends.backend import ParallelBackend +from ...core import DataType from ...utils.func_utils import prepare_function_args from ..common import ( DataEvalType, - DataType, EvaluateAllType, EvaluateGradientsType, EvaluateType, @@ -104,7 +104,7 @@ def __init__(self, pm: PhysicalModel[DataType]) -> None: self.globals: list[ast.stmt] = [] self.functions: list[ast.stmt] = [] - def generate_code(self, file_path: str | None = None): + def generate_code(self, file_path: str | None = None) -> None: self.file_path = file_path self.imports += self.generate_imports() self.functions += self.generate_functions() @@ -119,10 +119,10 @@ def generate_code(self, file_path: str | None = None): if file_path is not None: self.write_code(file_path) - def generate_functions(self): + def generate_functions(self) -> list[ast.FunctionDef]: return [self.generate_evaluate()] - def write_code(self, file_path: str): + def write_code(self, file_path: str) -> None: if self.code is None: raise Exception( "Code is not generated yet! Please call generate_code() first." @@ -151,7 +151,7 @@ def exec_generated_code( if self.file_path is not None: module_name = splitext(basename(self.file_path))[0] - module_spec = importlib.util.spec_from_file_location( # type: ignore + module_spec = importlib.util.spec_from_file_location( module_name, self.file_path ) module = importlib.util.module_from_spec(module_spec) # type: ignore @@ -247,7 +247,7 @@ def post_process_fns( return eval_fn, grad_fn, evaluate_all_fn - def import_backend(self): + def import_backend(self) -> ast.ImportFrom: backend = ast.ImportFrom( module="mithril", names=[ @@ -261,7 +261,7 @@ def import_backend(self): return backend - def generate_imports(self): + def generate_imports(self) -> list[ast.stmt]: imports: list[ast.stmt] = [] # Add import primitive functions imports.append( @@ -295,7 +295,9 @@ def generate_imports(self): return imports - def get_primitive_details(self, output_key: str): + def get_primitive_details( + self, output_key: str + ) -> tuple[PrimitiveModel, list[str], list[str]]: model = self.pm.flat_graph.get_model(output_key) global_input_keys = self.pm.flat_graph.get_source_keys(output_key) @@ -311,7 +313,7 @@ def call_primitive( g_input_keys: list[str], output_key: str, formula_key: str, - ): + ) -> tuple[ast.Assign, set[str]]: generated_fn, used_keys = self.create_primitive_call( fn, l_input_keys, g_input_keys ) @@ -321,9 +323,9 @@ def call_primitive( if formula_key in self.pm.backend.array_creation_funcs: self.add_partial_function(formula_key) - return ast.Assign(targets, generated_fn), used_keys | _used_keys # type: ignore + return ast.Assign(targets, generated_fn), used_keys | _used_keys - def generate_evaluate(self): + def generate_evaluate(self) -> ast.FunctionDef: input_body: list[ast.stmt] = [] function_body: list[ast.stmt] = [] return_values: list[ast.expr] = [] @@ -424,7 +426,9 @@ def generate_evaluate(self): return ast.fix_missing_locations(func_def) - def append_inputs(self, input_body: list[ast.stmt], key: str, dict_type: str): + def append_inputs( + self, input_body: list[ast.stmt], key: str, dict_type: str + ) -> None: # In manual_grad type backends, cache contains all the required # data (local variables and outputs) for the corresponding function. # So if the key is not directly an output of a function get it from @@ -538,7 +542,7 @@ def create_primitive_call_targets( return targets, {target_name} - def add_partial_function(self, formula_key: str): + def add_partial_function(self, formula_key: str) -> None: if formula_key in self.defined_partial_fns: return @@ -561,7 +565,7 @@ def create_gradient_fn( # raw_evaluate_grad_fn: ManualGradWrapperFn[DataType] | None, raw_evaluate_fn: RawEvaluateType[DataType], raw_evaluate_grad_fn: ManualGradWrapperFn[DataType] | None, - ): + ) -> tuple[ManualGradWrapperFn[DataType], RawEvaluateType[DataType]]: fn_all: EvaluateAllType[DataType] grad_fn: EvaluateGradientsType[DataType] if not self.pm.backend.is_manualgrad: @@ -572,20 +576,20 @@ def create_gradient_fn( include_output=False, ) # Fix fn_all for mlx support!! - fn_all = partial( + fn_all = partial( # type: ignore self.compute_gradients, raw_evaluate_fn=raw_evaluate_fn, cache=self.pm.data_store.data_values, include_output=True, ) - return grad_fn, fn_all + return grad_fn, fn_all # type: ignore else: assert raw_evaluate_grad_fn is not None, "Gradient function is not defined!" fn_all = partial(raw_evaluate_grad_fn, include_output=True) # type: ignore grad_fn = partial(raw_evaluate_grad_fn, include_output=False) # type: ignore - return grad_fn, fn_all + return grad_fn, fn_all # type: ignore @overload def compute_gradients( diff --git a/mithril/framework/codegen/torch_gen.py b/mithril/framework/codegen/torch_gen.py index 1b326de2..b3298a0a 100644 --- a/mithril/framework/codegen/torch_gen.py +++ b/mithril/framework/codegen/torch_gen.py @@ -40,7 +40,7 @@ def call_primitive( g_input_keys: list[str], output_key: str, formula_key: str, - ): + ) -> tuple[ast.Assign, set[str]]: generated_fn, used_keys = self.create_primitive_call( fn, l_input_keys, g_input_keys ) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index f5257773..ad44731a 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -195,6 +195,7 @@ class KeyType(Enum): | tuple[Any, ...] | list[Any] | dict[Any, Any] + | Mapping[Any, Any] | Constant | slice | PaddingType @@ -212,6 +213,7 @@ class KeyType(Enum): | tuple[Any, ...] | list[Any] | dict[Any, Any] + | Mapping[Any, Any] | bool | None | EllipsisType @@ -291,7 +293,7 @@ def __call__( def update_equivalence_table( item1: ItemType, item2: ItemType, lookup_table: dict[ItemType, set[ItemType]] -): +) -> None: item_set1 = lookup_table.get(item1) item_set2 = lookup_table.get(item2) if item_set1 is None and item_set2 is None: @@ -321,7 +323,7 @@ class ConstraintSolver: default_factory=lambda: {} ) - def __call__(self, updates: Updates): + def __call__(self, updates: Updates) -> None: self.update_shapes(updates) solved_constrs: set[Constraint] = set() for constr_type in UpdateType: @@ -332,7 +334,7 @@ def _solver_loop( constraint_type: UpdateType, updates: Updates, solved_constraints: set[Constraint], - ): + ) -> None: constraints = updates.constraints[constraint_type] while constraints: constr = constraints.pop() @@ -370,14 +372,14 @@ def _solver_loop( constraints.discard(constr) @staticmethod - def _combine_nodes(updates: Updates): + def _combine_nodes(updates: Updates) -> None: # Check if any node could be reduced after variadic updates add into # node_updates field. while updates.node_updates: node = updates.node_updates.pop() updates |= node.combine() - def _reduce_uniadic_referees(self, updates: Updates): + def _reduce_uniadic_referees(self, updates: Updates) -> None: while updates.uniadic_updates: uni = updates.uniadic_updates.pop() uni_val = uni.value @@ -445,7 +447,7 @@ def _add_sublists( return updates - def clear(self): + def clear(self) -> None: self.symbol_store = {} self.constraint_map = {} self.empty_node = None @@ -479,7 +481,7 @@ def _delete_node(remaining: ShapeNode, deleted: ShapeNode) -> Updates: deleted.reprs = [] return updates - def update_shapes(self, updates: Updates): + def update_shapes(self, updates: Updates) -> None: deletion_nodes: dict[ShapeNode, set[ShapeNode]] = {} # Note that update can be tuple also. First element of update # is always Tensor | Scalar. So this is the case, get first element @@ -732,7 +734,7 @@ def match(self, other: BaseData[T]) -> Updates: self.finalize_match(other) return updates - def set_value(self, value: AllValueType) -> Updates: # type: ignore[override] + def set_value(self, value: AllValueType) -> Updates: updates = Updates() if self.value is not TBD and self.value != value: raise ValueError( @@ -1375,7 +1377,7 @@ def get_shape_node(self, key: str) -> ShapeNode: return data.shape def set_value(self, con: ConnectionData, value: MainValueType): - self.get_data(con.key).set_value(value) # type: ignore + self.get_data(con.key).set_value(value) def extract_metadata(self, key: str | Connection) -> IOHyperEdge: if isinstance(key, Connection): @@ -2721,7 +2723,7 @@ def __call__(self, keys: list[Scalar | Tensor]) -> ConstrainResultType: def add_post_process(self, fn: ConstraintFunctionType): self.post_processes.add(fn) - def create_post_constraints(self): + def create_post_constraints(self) -> set[Constraint]: constraints: set[Constraint] = set() for fn in self.post_processes: constraints.add(Constraint(fn, self.type)) @@ -2879,7 +2881,7 @@ def compile( # adjust the table accordingly self._adjust_table() # calculate total table width - table_width = reduce( # type: ignore + table_width = reduce( partial(add_lengths, const=(len(row) for row in row_sep)), self.each_row_width, ) @@ -2930,7 +2932,7 @@ def compile( self.header_str = header_str self.cell_str = cell_str - def display(self): + def display(self) -> None: """Prints the table""" print(self.header_str) print(self.cell_str) @@ -2941,7 +2943,7 @@ def construct_subtable_row( arg_max_lengths: list[int], adjustments: list[str], *args: list[list[str]], - ): + ) -> list[str]: # Constructs subtable with given args subtable_list: list[str] = [] elems: tuple[list[str], ...] diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index 7c99c02b..8e423b89 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -105,7 +105,7 @@ # Below functions are used in various constraints. def prod_fn(a: int | Uniadic, b: int | Uniadic) -> int: - return (a if isinstance(a, int) else a.value) * ( # type: ignore + return (a if isinstance(a, int) else a.value) * ( b if isinstance(b, int) else b.value ) @@ -480,7 +480,7 @@ def scalar_item_reduce_input_type( inner_type.append(output_type) if arg is tuple: inner_type.append(...) - possible_types.append(arg[*inner_type]) # type: ignore + possible_types.append(arg[*inner_type]) return create_union_type(*possible_types) elif isinstance(input_type, GenericAlias): input_origin = input_type.__origin__ diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 2804463a..831628a8 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -327,7 +327,7 @@ def _set_value(self, key: ConnectionData, value: MainValueType | str) -> Updates if key.key not in self.conns.input_keys: raise ValueError("Values of internal and output keys cannot be set.") # Data is scalar, set the value directly. - return key.metadata.data.set_value(value) # type: ignore + return key.metadata.data.set_value(value) def set_shapes( self, config: ShapesType | None = None, **kwargs: ShapeTemplateType diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 3ca9b5e3..2b4a4f21 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -97,7 +97,7 @@ def __init__( assert isinstance(value, BaseKey) _value = Tensor( shape=shapes[key].node, - possible_types=get_mytensor_subtype(value_type), # type: ignore + possible_types=get_mytensor_subtype(value_type), value=value.value, # type: ignore interval=value.interval, ) @@ -107,7 +107,7 @@ def __init__( else: _value = Scalar( possible_types=value_type, # type: ignore - value=value.value, # type: ignore + value=value.value, ) conn_data = self.create_connection(IOHyperEdge(_value), key) diff --git a/mithril/framework/physical/data_store.py b/mithril/framework/physical/data_store.py index ee1f9593..c1d427b4 100644 --- a/mithril/framework/physical/data_store.py +++ b/mithril/framework/physical/data_store.py @@ -17,14 +17,13 @@ from typing import Any, Generic, TypeGuard from ...backends.backend import Backend -from ...core import DataType, data_types, epsilon_table +from ...core import Constant, DataType, data_types, epsilon_table from ...utils.func_utils import is_make_array_required, prepare_function_args from ...utils.utils import BiMap from ..common import ( TBD, AllValueType, Connection, - Constant, ConstraintSolver, DataEvalType, MainValueInstance, @@ -55,7 +54,7 @@ def __init__( self.graph: FlatGraph[DataType] = graph self.backend: Backend[DataType] = backend self.inference = inference - self._intermediate_non_differentiables: BiMap[str, Tensor | Scalar] = BiMap() + self.intermediate_non_differentiables: BiMap[str, Tensor | Scalar] = BiMap() # type: ignore self._runtime_static_keys: set[str] = set() self._unused_keys: set[str] = set() # Final tensor values of data store. @@ -63,11 +62,11 @@ def __init__( self.constraint_solver: ConstraintSolver = deepcopy(solver, memo=memo) @property - def all_data(self): + def all_data(self) -> dict[str, Tensor | Scalar]: return self._all_data @property - def cached_data(self): + def cached_data(self) -> DataEvalType[DataType]: return self.data_values @property @@ -86,19 +85,19 @@ def unused_keys(self) -> set[str]: def is_scalar_type(t: Any) -> TypeGuard[MainValueType]: return not isinstance(t, tuple(data_types)) - def remove_keys_from_store(self, keys: set[str]): + def remove_keys_from_store(self, keys: set[str]) -> None: keys -= set(self.graph.output_keys) for key in keys: self.remove_key_from_store(key, label_as_unused=False, hard_remove=True) def remove_key_from_store( self, key: str, label_as_unused: bool = True, hard_remove: bool = False - ): + ) -> None: if key in self.data_values: self.data_values.pop(key) # type: ignore self._runtime_static_keys.discard(key) - if key in self._intermediate_non_differentiables: - self._intermediate_non_differentiables.pop(key) + if key in self.intermediate_non_differentiables: + self.intermediate_non_differentiables.pop(key) if label_as_unused: self._unused_keys.add(key) @@ -110,7 +109,7 @@ def remove_key_from_store( self._all_data.pop(key) self._clear_constraints(key) - def _clear_constraints(self, key: str): + def _clear_constraints(self, key: str) -> None: if key not in self._all_data: return @@ -121,16 +120,16 @@ def _clear_constraints(self, key: str): self._all_data[source_key].shape_constraints -= shape_constraints self._all_data[source_key].type_constraints -= type_constraints - def _update_cached_data(self, updated_data: Updates) -> set[str]: + def update_cached_data(self, updated_data: Updates) -> set[str]: # If any data value is found by shape inference algorithms # transfer this data in cached_data. transferred_keys: set[str] = set() updated_inter_data = ( updated_data.value_updates - & self._intermediate_non_differentiables.inverse.keys() + & self.intermediate_non_differentiables.inverse.keys() ) for data in updated_inter_data: - key = self._intermediate_non_differentiables.inverse[data] + key = self.intermediate_non_differentiables.inverse[data] if key in self.data_values or data.value is not TBD: if key in self.data_values: raise KeyError( @@ -141,10 +140,10 @@ def _update_cached_data(self, updated_data: Updates) -> set[str]: self._set_data_value(key, data) transferred_keys.add(key) for key in transferred_keys: - self._intermediate_non_differentiables.pop(key) + self.intermediate_non_differentiables.pop(key) return transferred_keys - def _set_data_value(self, key: str, data: Tensor | Scalar): + def _set_data_value(self, key: str, data: Tensor | Scalar) -> None: value: DataType | AllValueType = data.value assert not isinstance(value, ToBeDetermined) if isinstance(data, Tensor): @@ -155,7 +154,7 @@ def _set_data_value(self, key: str, data: Tensor | Scalar): self.data_values[key] = value # type: ignore - def _infer_unused_keys(self, key: str): + def _infer_unused_keys(self, key: str) -> None: # Infers unused keys when "key" is set as static. output_keys = self.graph.output_keys queue = set(self.graph.get_source_keys(key, True)) @@ -195,11 +194,11 @@ def set_shapes( updates |= data.shape.set_values(value) self.constraint_solver(updates) # Some intermediate values may be calculated, update cached data. - new_statics = self._update_cached_data(updates) + new_statics = self.update_cached_data(updates) for key in new_statics: self._infer_unused_keys(key) - def update_data(self, data: dict[str, Tensor | Scalar]): + def update_data(self, data: dict[str, Tensor | Scalar]) -> None: if data.keys() & self._all_data.keys(): raise Exception("Some keys are already in data store!") self._all_data |= data @@ -225,7 +224,7 @@ def update_data(self, data: dict[str, Tensor | Scalar]): if value.value is not TBD: self._set_data_value(key, value) else: - self._intermediate_non_differentiables[key] = value + self.intermediate_non_differentiables[key] = value def set_static_keys( self, @@ -272,11 +271,11 @@ def add_static_data( if isinstance(data, Tensor): assert not isinstance(value, MainValueInstance) # Find shape of tensor and set. - shape = list(value.shape) + shape = list(value.shape) # type: ignore updates |= data.shape.set_values(shape) # Find type of tensor and set. val_type: type[bool] | type[int] | type[float] - data_dtype = str(value.dtype) + data_dtype = str(value.dtype) # type: ignore # Check value type is OK, and update type accordinly. if "bool" in data_dtype: val_type = bool @@ -290,7 +289,7 @@ def add_static_data( "Only float, int or bool types are accepted." ) updates |= data.set_type(val_type) - elif isinstance(data, Scalar) and self.is_scalar_type(value): + elif self.is_scalar_type(value): updates |= data.set_value(value) else: raise ValueError( @@ -298,16 +297,16 @@ def add_static_data( f"the type of data: {type(data)}!" ) self.data_values[key] = value # type: ignore - self._intermediate_non_differentiables.pop(key, None) + self.intermediate_non_differentiables.pop(key, None) if ( - key not in self._intermediate_non_differentiables + key not in self.intermediate_non_differentiables and key in self.runtime_static_keys ): self._runtime_static_keys.remove(key) # Finally update cached_data, infer unused keys and # return newly added static keys. self.constraint_solver(updates) - statics = self._update_cached_data(updates) | updated_keys + statics = self.update_cached_data(updates) | updated_keys for static in statics: self._infer_unused_keys(static) diff --git a/mithril/framework/physical/flat_graph.py b/mithril/framework/physical/flat_graph.py index 0a1291c4..72a98768 100644 --- a/mithril/framework/physical/flat_graph.py +++ b/mithril/framework/physical/flat_graph.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Callable, Mapping +from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass from ...core import DataType, GenericDataType @@ -52,7 +52,7 @@ class Connection: target_keys: list[str] connections: set[Connection] - def __hash__(self): + def __hash__(self) -> int: return hash(id(self)) @@ -103,7 +103,7 @@ def __init__(self, input_keys: set[str], output_keys: set[str]) -> None: self.value_table: dict[str, DataType | ValueType] = {} @property - def hanging_keys(self): + def hanging_keys(self) -> set[str]: hanging_keys = (self.all_target_keys - self.all_source_keys) | set( self.connections.keys() ) - self.all_target_keys - self.all_source_keys @@ -111,22 +111,22 @@ def hanging_keys(self): return hanging_keys - set(self.output_dict.values()) @property - def input_keys(self): + def input_keys(self) -> set[str]: return set(self._input_keys) @property - def output_keys(self): + def output_keys(self) -> set[str]: return set(self.output_dict.keys()) @property - def all_keys(self): + def all_keys(self) -> set[str]: return ( set(self.connections.keys()) | set(self.output_dict.keys()) | set(self.output_dict.values()) ) - def add_value(self, model: PrimitiveModel, keys: dict[str, str]): + def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> None: output_key = keys[PrimitiveModel.output_key] keys = { key: self._temp_connection_info.get(value, value) @@ -178,7 +178,7 @@ def add_value(self, model: PrimitiveModel, keys: dict[str, str]): self._update_all_source_keys() self._update_all_target_keys() - def collapse_model_keys(self, output_key: str, new_reference_key: str): + def collapse_model_keys(self, output_key: str, new_reference_key: str) -> None: # If a model removed, the models that uses the output of the removed model # should be updated with the new reference key. for key, value in self._temp_connection_info.items(): @@ -197,7 +197,7 @@ def update_output_keys(self, output_key: str, new_reference_key: str) -> bool: return True @property - def topological_order(self): + def topological_order(self) -> list[str]: return self._topological_order @property @@ -208,15 +208,13 @@ def all_target_keys(self) -> set[str]: def all_source_keys(self) -> set[str]: return self._all_source_keys - def _update_topological_order(self): + def _update_topological_order(self) -> None: self._topological_order = [ node.connections[PrimitiveModel.output_key].key for node in self.nodes.values() - if node.model is not None - or node.connections[PrimitiveModel.output_key].key in self.output_keys ] - def _update_all_source_keys(self): + def _update_all_source_keys(self) -> None: self._all_source_keys = { conn.key for item in self.nodes.values() @@ -224,7 +222,7 @@ def _update_all_source_keys(self): if key != "output" } - def _update_all_target_keys(self): + def _update_all_target_keys(self) -> None: self._all_target_keys = { conn.key for item in self.nodes.values() @@ -232,7 +230,7 @@ def _update_all_target_keys(self): if key == "output" } - def _update_connection_keys(self, connection: Connection): + def _update_connection_keys(self, connection: Connection) -> None: source_keys: list[str] = [] target_keys: list[str] = [] @@ -243,7 +241,7 @@ def _update_connection_keys(self, connection: Connection): key = conn.key source_keys.append(key) - def get_target_keys(connection: Connection): + def get_target_keys(connection: Connection) -> list[str]: target_keys: list[str] = [] for conn in connection.connections: target_keys.append(conn.key) @@ -271,27 +269,27 @@ def get_target_keys(connection: Connection): connection.target_keys = list(target_keys) connection.source_keys = list(source_keys) - def get_model(self, key) -> PrimitiveModel: + def get_model(self, key: str) -> PrimitiveModel: conn = self.connections.get(key, None) if conn is None or conn.node is None: raise ValueError(f"Model not found for key: {key}") return conn.node.model - def get_model_out_key(self, model: PrimitiveModel): + def get_model_out_key(self, model: PrimitiveModel) -> str | None: node = self.nodes.get(model, None) if node is None: return None return node.connections[PrimitiveModel.output_key].key - def get_model_outer_key(self, model: PrimitiveModel, inner_key: str): + def get_model_outer_key(self, model: PrimitiveModel, inner_key: str) -> str: return self.nodes[model].connections[inner_key].key - def get_model_connections(self, model: PrimitiveModel): + def get_model_connections(self, model: PrimitiveModel): # type: ignore return self.nodes[model].connections.values() - def get_connection(self, key: str): - return self.connections.get(key, None) + def get_connection(self, key: str) -> Connection | None: + return self.connections.get(key) def get_source_keys(self, key: str, include_outputs: bool = False) -> list[str]: source_keys: list[str] = [] @@ -342,10 +340,7 @@ def _is_duplicate( node: Node, data: dict[str, Tensor | Scalar], constant_keys: Mapping[str, DataType | MainValueType], - ): - if node.model is None: - return - + ) -> Connection | None: # Model id is a unique key for unique operation model_id: list[str] = [] for key, conn in node.connections.items(): @@ -369,7 +364,7 @@ def _is_duplicate( elif self.is_tensor_type(ref_value) and self.is_tensor_type(value): is_equal = ( id(ref_value) == id(value) - or ref_value.shape == value.shape + or ref_value.shape == value.shape # type: ignore and (ref_value == value).all().item() # type: ignore ) else: @@ -392,8 +387,9 @@ def _is_duplicate( return self.unique_model_table[final_model_id] self.unique_model_table[final_model_id] = node.connections["output"] + return None - def _prune_node(self, node: Node, conn: Connection): + def _prune_node(self, node: Node, conn: Connection) -> None: self.collapse_model_keys(node.connections["output"].key, conn.key) # Update source and target keys of node connections @@ -424,15 +420,14 @@ def _prune_node(self, node: Node, conn: Connection): ) not in self.output_keys and key in self._all_target_keys: self._all_target_keys.remove(key) - if node.model is not None: - self.nodes.pop(node.model) + self.nodes.pop(node.model) self._update_connection_keys(conn) self._update_all_source_keys() self._update_all_target_keys() self._update_topological_order() - def _remove_node(self, node: Node): + def _remove_node(self, node: Node) -> None: connections = set(node.connections.values()) output_conn = node.connections[PrimitiveModel.output_key] @@ -444,12 +439,11 @@ def _remove_node(self, node: Node): self._update_connection_keys(conn) self._remove_conn(output_conn) - if node.model is not None: - self.nodes.pop(node.model) + self.nodes.pop(node.model) self._update_topological_order() - def _remove_conn(self, conn: Connection): + def _remove_conn(self, conn: Connection) -> None: self.connections.pop(conn.key, None) # Remove connection from other connections @@ -458,13 +452,13 @@ def _remove_conn(self, conn: Connection): if conn.key in conn_.target_keys: conn_.target_keys.remove(conn.key) - if conn.key in self._all_source_keys: # and conn.key not in self.alias_map: + if conn.key in self._all_source_keys: self._all_source_keys.remove(conn.key) - if conn.key in self._all_target_keys: # and conn.key not in self.alias_map: + if conn.key in self._all_target_keys: self._all_target_keys.remove(conn.key) - def remove_key(self, key: str): + def remove_key(self, key: str) -> None: if key in self.output_dict: self.output_dict.pop(key) @@ -475,7 +469,7 @@ def remove_key(self, key: str): def infer_ignore_step( self, key: str, keys: set[str], queue: set[str], from_source: bool - ): + ) -> None: forward_key_fn: Callable[[str, bool], list[str]] if from_source: forward_key_fn = self.get_target_keys @@ -497,5 +491,5 @@ def infer_ignore_step( keys.add(value) queue.add(value) - def get_models(self): + def get_models(self) -> Iterable[PrimitiveModel]: return self.nodes.keys() diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 5be36fa4..3078addd 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math import warnings from collections.abc import Callable, Mapping, Sequence @@ -27,6 +29,7 @@ TBD, Connection, ConnectionData, + ConnectionType, DataEvalType, EvaluateAllType, EvaluateGradientsType, @@ -97,7 +100,7 @@ def __init__( # TODO: Remove wrapping with Model in the future. model = deepcopy(model) extend_info = model() - model_keys = {} + model_keys: dict[str, ConnectionType] = {} for key in model.external_keys: value = extend_info.connections.get(key, NOT_GIVEN) # NOTE: Do not set default value if it is given in constant_keys. @@ -140,7 +143,7 @@ def __init__( ].name key_origin = model.canonical_output.metadata.key_origin if key_origin != current_name: - while key_origin in flat_model.assigned_edges: + while key_origin in flat_model.assigned_names: key_origin = f"_{key_origin}" self._output_keys.add(key_origin) @@ -232,7 +235,9 @@ def __init__( if self.backend.backend_type == "numpy": cache_name = "_".join([mappings[output], p_model.cache_name]) mappings["cache"] = cache_name - cache_value: dict | None = None if self.inference else dict() + cache_value: DataEvalType[DataType] | None = ( + None if self.inference else dict() + ) # Create A object for caches in manualgrad backend. cache_scalar = Scalar(dict | None, cache_value) self.data_store.update_data({cache_name: cache_scalar}) @@ -271,9 +276,9 @@ def __init__( def __call__( self, - params: dict[str, DataType] | None = None, - data: Mapping[str, DataType | MainValueType] | None = None, - ): + params: ParamsEvalType[DataType] | None = None, + data: DataEvalType[DataType] | None = None, + ) -> DataEvalType[DataType]: return self.evaluate(params=params, data=data) def _convert_key(self, model: BaseModel, key: str | Connection) -> str: @@ -395,7 +400,7 @@ def get_shapes( ) @property - def data(self): + def data(self) -> dict[str, Tensor | Scalar]: return self.data_store._all_data @property @@ -403,14 +408,16 @@ def shapes(self) -> _ShapesType: return self.get_shapes() @property - def output_keys(self): + def output_keys(self) -> list[str]: return sorted(self._output_keys) @property - def input_keys(self): + def input_keys(self) -> set[str]: return self._input_keys - def _infer_differentiability(self, model: PrimitiveModel, dag: dict[str, str]): + def _infer_differentiability( + self, model: PrimitiveModel, dag: dict[str, str] + ) -> None: # Infer output differentiability only for the models # that have a Tensor type output. if isinstance(model.output.metadata.data, Tensor): @@ -506,7 +513,7 @@ def _pre_compile( data_keys: set[str], shapes: PhysicalShapeType, jacobian_keys: set[str], - ): + ) -> None: if jacobian_keys and self.backend.is_manualgrad: raise Exception( "Jacobians are only calculated for the backends that have " @@ -546,16 +553,16 @@ def _pre_compile( logical_id = reverse_data_memo[pruned_data] self.data_store.data_memo[logical_id] = remained_data - updates |= remained_data.match(pruned_data) + updates |= remained_data.match(pruned_data) # type: ignore self.data[key] = remained_data - for value in self.data_store._intermediate_non_differentiables.inverse: + for value in self.data_store.intermediate_non_differentiables.inverse: # there can exist some inferred intermediate scalar keys in logical model. # find those keys and add to cached datas if isinstance(value, Scalar) and value.value is not TBD: updates.add(value) - self.data_store._update_cached_data(updates) + self.data_store.update_cached_data(updates) self.data_store.constraint_solver(updates) @@ -606,7 +613,7 @@ def generate_functions( ) self._generated_evaluate_all_fn: EvaluateAllType[DataType] | None = eval_all_fn - def create_jacobian_fn(self, generated_fn: Callable): + def create_jacobian_fn(self, generated_fn: Callable): # type: ignore # TODO: Fix this method to make it picklable! if self.backend.is_manualgrad: raise ( @@ -616,10 +623,10 @@ def create_jacobian_fn(self, generated_fn: Callable): ) # TODO: Consider to JIT this function. - def multiplier(x, y): - return x * y + def multiplier(x, y): # type: ignore + return x * y # type: ignore - def jacobian_fn( + def jacobian_fn( # type: ignore inputs: dict[str, DataType], data: dict[str, DataType] | None = None ): # Function for calculating jacobians for the requested @@ -628,10 +635,10 @@ def jacobian_fn( if data is None: data = {} - def jacobian_wrapper(input, output): - total_inputs = inputs | input + def jacobian_wrapper(input, output): # type: ignore + total_inputs = inputs | input # type: ignore - return generated_fn(params=total_inputs, data=data)[output] + return generated_fn(params=total_inputs, data=data)[output] # type: ignore jacobians: dict[str, dict[str, DataType]] = {} @@ -644,7 +651,7 @@ def jacobian_wrapper(input, output): jacobians[out] = {} # Iterate over all trainable inputs. - jacobian_par_fn = jacobian_method(partial(jacobian_wrapper, output=out)) + jacobian_par_fn = jacobian_method(partial(jacobian_wrapper, output=out)) # type: ignore for key in inputs: # if all(isinstance(dim, int) for dim in self.shapes[out]) and all( @@ -662,8 +669,9 @@ def jacobian_wrapper(input, output): # for wide Jacobian matrices where output dimensionalitiy # is lower than input dimensionality. # jacfwd is more efficient in oppisite condition. - cond = reduce(multiplier, out_shp) >= reduce( - multiplier, key_shp + cond = reduce(multiplier, out_shp) >= reduce( # type: ignore + multiplier, # type: ignore + key_shp, ) jacobian_method = [self.backend.jacrev, self.backend.jacfwd][ # type: ignore cond @@ -747,9 +755,9 @@ def infer_ignore( def _calculate_parameters( self, - name_mappings: dict[Model, str], + name_mappings: dict[BaseModel, str], data_to_key_map: dict[Tensor | Scalar, list[str]] | None = None, - ): + ) -> tuple[dict[str, tuple[dict[str, str], dict[str, str]]], str]: total_params: int = 0 seen_data: set[Tensor] = set() exact_param_status: bool = True @@ -821,7 +829,7 @@ def _print_model_info( total_params: str, data_to_key_map: dict[Tensor | Scalar, list[str]], model: BaseModel | None = None, - ): + ) -> None: # Find constant inputs of the model. pm_constant_input_keys = ( self._input_keys - self.data_store.unused_keys @@ -889,7 +897,7 @@ def summary( alternative_shapes: bool = False, print_info: bool = True, name: str | None = None, - ): + ) -> None: uni_keys: dict[UniadicRecord, str] = dict() var_keys: dict[Variadic, str] = dict() if model is None and depth != 0: @@ -909,13 +917,9 @@ def summary( type_info = None # Extract all summary information - dag: list[PrimitiveModel] | dict[BaseModel, dict[str, ConnectionData]] + dag: list[BaseModel] | dict[BaseModel, dict[str, ConnectionData]] if model is not None: - if isinstance(model, PrimitiveModel): - dag = [model] - elif isinstance(model, Model): - dag = model.dag - + dag = model.dag if isinstance(model, Model) else [model] name_mappings = define_unique_names(dag) conn_info = model.extract_connection_info( name_mappings, data_to_key_map, self.data_store.data_memo @@ -930,7 +934,7 @@ def summary( all_models.remove(unused_model.node.model) name_mappings = define_unique_names(all_models) - conn_info = self.extract_connection_info(name_mappings) + conn_info = self.extract_connection_info(name_mappings) # type: ignore model_shapes: dict[str, _ShapesType] = { sub_model_name: self.get_shapes( @@ -941,7 +945,8 @@ def summary( # calculate all key parameters and total parameters param_info, total_parameters = self._calculate_parameters( - name_mappings, data_to_key_map + name_mappings, + data_to_key_map, # type: ignore ) if print_info: @@ -964,7 +969,7 @@ def summary( table = get_summary( conns=conn_info, name=name, - shape=shape_info, + shape=shape_info, # type: ignore types=type_info, params=param_info, ) @@ -987,7 +992,7 @@ def summary( def extract_connection_info( self, name_mappings: dict[PrimitiveModel, str] | None = None - ): + ) -> dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]]: if name_mappings is None: name_mappings = define_unique_names(self.flat_graph.get_models()) conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {} @@ -1084,7 +1089,7 @@ def _replace_with_primitive( ) primitive.parent = model.parent - p_key_mappings = {} + p_key_mappings: dict[str, str] = {} # for key in model._input_keys | model.output_keys: for key in model.external_keys: if key[0] != "$": @@ -1151,17 +1156,17 @@ class Name: name: str origin: str - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Name): return self.name == other.name if isinstance(other, str): return self.name == other return False - def startswith(self, prefix: str): + def startswith(self, prefix: str) -> bool: return self.name.startswith(prefix) @@ -1217,7 +1222,7 @@ def external_keys(self) -> set[str]: """ return set(self.external_mapping.values()) - def rename_key(self, source_name: str, target_name: str): + def rename_key(self, source_name: str, target_name: str) -> None: """ Rename a key from source_name to target_name. @@ -1234,7 +1239,7 @@ def rename_key(self, source_name: str, target_name: str): self._update_defined_names(source_name, target_name) - def _update_defined_names(self, old_key: str, new_key: str): + def _update_defined_names(self, old_key: str, new_key: str) -> None: old_name = self.assigned_names[old_key] if old_name.origin in self.key_origins: if self.key_origins[old_name.origin] == 0: @@ -1251,7 +1256,7 @@ def _update_defined_names(self, old_key: str, new_key: str): for key, value in self._external_mapping.items() } - def _name_externals(self): + def _name_externals(self) -> None: external_keys = list(self.model.conns.input_keys) + list( self.model.conns.output_keys ) @@ -1329,7 +1334,7 @@ def generate_keys( model: BaseModel, mappings: dict[str, str] | None = None, parent_name: str = "", - ): + ) -> None: """ Generate keys for the model. @@ -1355,7 +1360,7 @@ def generate_keys( def _process_primitive_model( self, model: PrimitiveModel, mappings: dict[str, str], parent_name: str - ): + ) -> None: """ Process a primitive model. @@ -1389,7 +1394,9 @@ def _process_primitive_model( self.used_edges.add(output_edge) self._check_for_queue(output_edge) - def _process_model(self, model: Model, mappings: dict[str, str], parent_name: str): + def _process_model( + self, model: Model, mappings: dict[str, str], parent_name: str + ) -> None: submodel_names = model.get_unique_submodel_names() for m, value in model.dag.items(): @@ -1421,7 +1428,7 @@ def _process_model(self, model: Model, mappings: dict[str, str], parent_name: st self.generate_keys(m, name_mapping, parent_name=name) - def _check_for_queue(self, hyperedge: IOHyperEdge): + def _check_for_queue(self, hyperedge: IOHyperEdge) -> None: if hyperedge in self.queued_models: for m, mappings, parent_name in self.queued_models[hyperedge]: if self._is_primitive_ready(m): @@ -1429,7 +1436,7 @@ def _check_for_queue(self, hyperedge: IOHyperEdge): m, mappings=mappings, parent_name=parent_name ) - def _is_primitive_ready(self, model: PrimitiveModel): + def _is_primitive_ready(self, model: PrimitiveModel) -> bool: """ Check if a primitive model is ready to be processed. @@ -1447,7 +1454,7 @@ def _is_primitive_ready(self, model: PrimitiveModel): def _add_primitive_to_queue( self, model: PrimitiveModel, mappings: dict[str, str], parent_name: str - ): + ) -> None: """ Add a primitive model to the queue. @@ -1498,7 +1505,7 @@ def _create_name(self, name: str, key_origin: str) -> Name: self.assigned_names[name] = new_name return new_name - def _rebase_names(self): + def _rebase_names(self) -> None: """ Rebase the names to remove unnecessary suffixes. """ @@ -1510,10 +1517,10 @@ def _rebase_names(self): self.assigned_names[name].name = base_name self.assigned_names[base_name] = self.assigned_names.pop(name) - def __iter__(self): + def __iter__(self) -> FlatModel: self._iter = iter(self.mappings.items()) return self - def __next__(self): + def __next__(self) -> tuple[PrimitiveModel, dict[str, str]]: model, mapping = next(self._iter) return model, {key: name.name for key, name in mapping.items()} diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 5d3055d5..6a259eb5 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterable from functools import reduce from itertools import product from types import FunctionType, GenericAlias, UnionType -from typing import Any +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from .logical.base import BaseModel class NestedListType: @@ -33,11 +37,14 @@ def __init__(self, base_type: type | UnionType): self.base_type = base_type -def define_unique_names(models): +T = TypeVar("T", bound="BaseModel") + + +def define_unique_names(models: Iterable[T]) -> dict[T, str]: # TODO: Move this to Physical model (currently it is only used there) # TODO: Also add short-naming logic to this function - model_name_dict = {} - single_model_dict = {} + model_name_dict: dict[T, str] = {} + single_model_dict: dict[str, T] = {} model_count_dict: dict[str, int] = {} for model in models: diff --git a/mithril/models/models.py b/mithril/models/models.py index cf9debf7..a1f0b880 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -1306,7 +1306,7 @@ def __init__( self._freeze() - def __call__( # type: ignore[override] + def __call__( self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN, @@ -1783,7 +1783,7 @@ def __init__( super().__init__(name=name) self.factory_inputs = kwargs - def __call__(self, **kwargs) -> ExtendInfo: # type: ignore[override] + def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: raise NotImplementedError("__call__ method not implemented!") @@ -1860,7 +1860,7 @@ def __init__( prev_cell = current_cell self._freeze() - def __call__( # type: ignore[override] + def __call__( self, input: ConnectionType = NOT_GIVEN, **model_keys: ConnectionType ) -> ExtendInfo: return super(RNN, self).__call__(input=input, **model_keys) @@ -1987,7 +1987,7 @@ def __init__( self._freeze() - def __call__( # type: ignore[override] + def __call__( self, hidden_concat: ConnectionType = NOT_GIVEN, **model_keys: ConnectionType ) -> ExtendInfo: return super(RNN, self).__call__(hidden_concat=hidden_concat, **model_keys) @@ -2091,7 +2091,7 @@ def __init__( self._freeze() - def __call__( # type: ignore[override] + def __call__( self, indices: ConnectionType = NOT_GIVEN, **model_keys: ConnectionType ) -> ExtendInfo: return super().__call__(indices=indices, **model_keys) diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index deab9392..ca903f65 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -17,23 +17,23 @@ from copy import deepcopy from typing import Any, Self -from ..framework import ( +from ..framework import BaseModel, ExtendInfo, Model +from ..framework.common import ( NOT_GIVEN, - BaseModel, + TBD, Connection, ConnectionData, ConnectionType, - ExtendInfo, IOHyperEdge, IOKey, KeyType, - Model, + NotAvailable, + Table, UniadicRecord, Variadic, get_shapes, get_summary_shapes, ) -from ..framework.common import TBD, NotAvailable, Table from ..framework.logical import ( Buffer, Divide, @@ -531,7 +531,7 @@ def summary( if isinstance(self._model, Model): summary_kwargs["depth"] = depth - self._model.summary(**summary_kwargs) # type: ignore + self._model.summary(**summary_kwargs) name_mappings = self.get_unique_submodel_names() conn_info = self.extract_connection_info(name_mappings) diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 6ecf608c..9f98eb10 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -350,7 +350,7 @@ def train_model_to_dict(context: TrainModel) -> dict: return context_dict -def dict_to_trainmodel(context_dict: dict): +def dict_to_trainmodel(context_dict: dict) -> BaseModel: model = dict_to_model(context_dict["model"]) assert isinstance(model, Model), "TrainModel requires a Model object!" diff --git a/mypy.ini b/mypy.ini index dfb40220..bd7544cc 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,4 +1,5 @@ [mypy] -check_untyped_defs = True -enable_incomplete_feature = NewGenericSyntax -ignore_missing_imports = True +strict = True +warn_return_any = False +exclude = ^tests/ + diff --git a/releases/generate_changelog.py b/releases/generate_changelog.py index b9becbca..3a51c92a 100644 --- a/releases/generate_changelog.py +++ b/releases/generate_changelog.py @@ -34,7 +34,7 @@ import sys from contextlib import redirect_stdout -from github import Github # type: ignore +from github import Github def print_pulls(repo_name, title, pulls): diff --git a/tests/scripts/test_data_store.py b/tests/scripts/test_data_store.py index 6c5e6030..717c8302 100644 --- a/tests/scripts/test_data_store.py +++ b/tests/scripts/test_data_store.py @@ -61,7 +61,7 @@ def test_data_store_1(): assert pm.data_store.data_values.keys() == {"input"} assert (pm.data_store.data_values[key].value == value).all() # type: ignore [union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -95,7 +95,7 @@ def test_data_store_1_numpy(): } assert (pm.data_store.data_values[key].value == value).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -112,7 +112,7 @@ def test_data_store_3(): assert pm.data_store.data_values.keys() == {"output_1"} assert (pm.data_store.data_values["output_1"] == backend.array(6.0)).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "input", "weight", @@ -205,7 +205,7 @@ def test_data_store_7(): assert pm.data_store.data_values.keys() == {"input"} assert (res["output"] == value).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -221,7 +221,7 @@ def test_data_store_8(): assert pm.data_store.data_values.keys() == {"output1"} assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"input"} @@ -238,7 +238,7 @@ def test_data_store_9(): assert pm.data_store.data_values.keys() == {"output1"} assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"input"} @@ -255,7 +255,7 @@ def test_data_store_10(): assert pm.data_store.data_values.keys() == {"input", "output2"} assert (pm.data_store.data_values["output2"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -272,7 +272,7 @@ def test_data_store_11(): assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert (pm.data_store.data_values["output3"] == backend.sigmoid(value) + 2).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "right", "input", @@ -297,7 +297,7 @@ def test_data_store_13(): assert pm.data_store.data_values.keys() == {"out"} assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"left", "right"} infered_value = pm.data_store.data_values["out"] @@ -331,7 +331,7 @@ def test_data_store_14(): ) assert pm.data_store.data_values.keys() == {"input1", "out2"} assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "output_0", @@ -385,7 +385,7 @@ def test_data_store_15(): ) assert pm.data_store.data_values.keys() == {"input1", "out2"} assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "output_6", @@ -439,7 +439,7 @@ def test_data_store_16(): "output_cache", } assert pm.data_store.runtime_static_keys == {"input"} - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -467,7 +467,7 @@ def test_data_store_17(): assert pm.data_store.data_values.keys() == {"output_0_cache", "output_cache"} assert pm.data_store.runtime_static_keys == {"right"} - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -498,7 +498,7 @@ def test_data_store_18(): assert pm.data_store.data_values.keys() == set() assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -529,7 +529,7 @@ def test_data_store_19(): assert pm.data_store.data_values.keys() == set() assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -560,5 +560,5 @@ def test_data_store_20(): assert pm.data_store.data_values.keys() == {"tensor_out"} assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == {"left", "output_1"} diff --git a/tests/scripts/test_flatmodel.py b/tests/scripts/test_flatmodel.py index a8214113..4e2d97b0 100644 --- a/tests/scripts/test_flatmodel.py +++ b/tests/scripts/test_flatmodel.py @@ -397,7 +397,7 @@ def test_integration_collision_from_different_levels(): pm_short = ml.compile(model, backend) pm_long = ml.compile(model, backend, use_short_namings=False) - input_short = {"d": backend.array([1, 2, 3]), "e_1": backend.array([4, 5, 6])} + input_short = {"d": backend.array([1, 2, 3]), "e": backend.array([4, 5, 6])} input_long = { "middle_d": backend.array([1, 2, 3]), "middle_e": backend.array([4, 5, 6]), @@ -406,7 +406,7 @@ def test_integration_collision_from_different_levels(): res_short = pm_short.evaluate(input_short) res_long = pm_long.evaluate(input_long) - expected_res = {"e": backend.array([5, 7, 9], dtype=ml.int64)} + expected_res = {"_e": backend.array([5, 7, 9], dtype=ml.int64)} - np.testing.assert_allclose(expected_res["e"], res_short["e"]) # type: ignore - np.testing.assert_allclose(expected_res["e"], res_long["e"]) # type: ignore + np.testing.assert_allclose(expected_res["_e"], res_short["_e"]) # type: ignore + np.testing.assert_allclose(expected_res["_e"], res_long["e"]) # type: ignore diff --git a/tests/utils.py b/tests/utils.py index 24b65b0d..2dfaf897 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -93,8 +93,8 @@ def check_physical_models( assert backend.all(value.value == pm_2.data[key].value) # type: ignore assert pm_1.data_store.cached_data.keys() == pm_2.data_store.cached_data.keys() assert ( - pm_1.data_store._intermediate_non_differentiables._table.keys() - == pm_2.data_store._intermediate_non_differentiables._table.keys() + pm_1.data_store.intermediate_non_differentiables._table.keys() + == pm_2.data_store.intermediate_non_differentiables._table.keys() ) assert ( pm_1.data_store.runtime_static_keys == pm_2.data_store.runtime_static_keys From 941b09119618312c36bddd61ed1b53192b57a10f Mon Sep 17 00:00:00 2001 From: mehmetozsoy1 Date: Tue, 31 Dec 2024 14:49:09 +0300 Subject: [PATCH 26/26] Jacobian codes are removed from codebase --- mithril/__init__.py | 3 - mithril/framework/physical/model.py | 99 +------------------ .../test_compile_keys_consistencies.py | 35 +++---- tests/scripts/test_data_store.py | 8 -- tests/scripts/test_inference.py | 1 - tests/scripts/test_pm.py | 1 - 6 files changed, 18 insertions(+), 129 deletions(-) diff --git a/mithril/__init__.py b/mithril/__init__.py index f1589072..1dc13bcc 100644 --- a/mithril/__init__.py +++ b/mithril/__init__.py @@ -100,7 +100,6 @@ def compile( constant_keys: PhysicalConstantType[DataType] | None = None, data_keys: Iterable[str | Connection] | None = None, discard_keys: Iterable[str | Connection] | None = None, - jacobian_keys: Iterable[str | Connection] | None = None, trainable_keys: Iterable[str | Connection] | None = None, shapes: PhysicalShapeType | None = None, inference: builtins.bool = False, @@ -135,7 +134,6 @@ def compile( constant_keys = constant_keys if constant_keys is not None else dict() data_keys = set(data_keys) if data_keys is not None else set() discard_keys = set(discard_keys) if discard_keys is not None else set() - jacobian_keys = set(jacobian_keys) if jacobian_keys is not None else set() shapes = shapes if shapes is not None else dict() trainable_keys = set(trainable_keys) if trainable_keys is not None else set() @@ -146,7 +144,6 @@ def compile( data_keys=data_keys, constant_keys=constant_keys, trainable_keys=trainable_keys, - jacobian_keys=jacobian_keys, discard_keys=discard_keys, shapes=shapes, inference=inference, diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 0a1363a1..9b191356 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -15,10 +15,9 @@ import math import random import warnings -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from functools import partial, reduce from ...backends.backend import Backend, ParallelBackend from ...core import DataType, GenericDataType @@ -87,7 +86,6 @@ def __init__( data_keys: StringOrConnectionSetType, constant_keys: PhysicalConstantType[DataType], trainable_keys: StringOrConnectionSetType, - jacobian_keys: StringOrConnectionSetType, shapes: PhysicalShapeType, inference: bool, safe_shapes: bool, @@ -155,7 +153,6 @@ def __init__( _trainable_keys = {self._convert_key(model, key) for key in trainable_keys} _discard_keys = {self._convert_key(model, key) for key in discard_keys} _shapes = {self._convert_key(model, k): v for k, v in shapes.items()} - _jacobian_keys = {self._convert_key(model, key) for key in jacobian_keys} # Check provided constant and data_keys do not have # any preset value. Note that this check is done after key conversions. @@ -164,9 +161,7 @@ def __init__( self._check_overridden_nontrainable_keys(model, constant_keys, data_keys) # Final validation process of provided keys. - self._validate_keys( - _constant_keys, _data_keys, _trainable_keys, _discard_keys, _jacobian_keys - ) + self._validate_keys(_constant_keys, _data_keys, _trainable_keys, _discard_keys) # Set provided non-differentiable and trainable tensor keys. self._non_differentiable_keys: set[str] = _constant_keys.keys() | _data_keys @@ -249,7 +244,6 @@ def __init__( self._pre_compile( constant_keys=_constant_keys, data_keys=_data_keys, - jacobian_keys=_jacobian_keys, shapes=_shapes, ) @@ -327,7 +321,6 @@ def _validate_keys( data_keys: set[str], trainable_keys: set[str], discard_keys: set[str], - jacobian_keys: set[str], ) -> None: # Make sure no common keys in constant_keys, data_keys, trainable_keys # and discard_keys. @@ -368,13 +361,6 @@ def _validate_keys( f"Invalid keys: {', '.join(str(key) for key in internal_discards)}." ) - # Given jacobian keys must be subset of input keys. - if jacobian_diff := (jacobian_keys - self._input_keys): - raise KeyError( - "Provided jacobian keys must be subset of the input keys. " - f"Invalid keys: {', '.join(str(key) for key in jacobian_diff)}." - ) - def get_shapes( self, model: BaseModel | None = None, @@ -514,15 +500,7 @@ def _pre_compile( constant_keys: dict[str, DataType | MainValueType], data_keys: set[str], shapes: PhysicalShapeType, - jacobian_keys: set[str], ): - if jacobian_keys and self.backend.is_manualgrad: - raise Exception( - "Jacobians are only calculated for the backends that have " - "autograd capability." - ) - - self.jacobian_keys = jacobian_keys self.ignore_grad_keys: set[str] = set() # Set given shapes. @@ -615,79 +593,6 @@ def generate_functions( ) self._generated_evaluate_all_fn: EvaluateAllType[DataType] | None = eval_all_fn - def create_jacobian_fn(self, generated_fn: Callable): - # TODO: Fix this method to make it picklable! - if self.backend.is_manualgrad: - raise ( - NotImplementedError( - "Currently Jacobian is not supported for manuel grad!" - ) - ) - - # TODO: Consider to JIT this function. - def multiplier(x, y): - return x * y - - def jacobian_fn( - inputs: dict[str, DataType], data: dict[str, DataType] | None = None - ): - # Function for calculating jacobians for the requested - # outputs stated in jacobian keys. We use more efficient - # jacobian method considerin input-output dimensionalities. - if data is None: - data = {} - - def jacobian_wrapper(input, output): - total_inputs = inputs | input - - return generated_fn(params=total_inputs, data=data)[output] - - jacobians: dict[str, dict[str, DataType]] = {} - - # Define default jacobian method as jacrev since - # output dimensionality is generally lower than input. - jacobian_method = self.backend.jacrev # type: ignore - - # Iterate over all requested outputs for Jacobian calculations. - for out in self.jacobian_keys: - jacobians[out] = {} - # Iterate over all trainable inputs. - - jacobian_par_fn = jacobian_method(partial(jacobian_wrapper, output=out)) - - for key in inputs: - # if all(isinstance(dim, int) for dim in self.shapes[out]) and all( - # isinstance(dim, int) for dim in self.shapes[key] - # ): - key_shp = self.shapes[key] - out_shp = self.shapes[out] - if ( - isinstance(key_shp, list) - and isinstance(out_shp, list) - and is_list_int(key_shp) - and is_list_int(out_shp) - ): - # If dimensions are known, jacrev is more efficient - # for wide Jacobian matrices where output dimensionalitiy - # is lower than input dimensionality. - # jacfwd is more efficient in oppisite condition. - cond = reduce(multiplier, out_shp) >= reduce( - multiplier, key_shp - ) - jacobian_method = [self.backend.jacrev, self.backend.jacfwd][ # type: ignore - cond - ] - # Provide input in dict format in order to get jacobians in dict - # format since all inputs are originally provided in dict format. - input = {key: inputs[key]} - # jacobians[out] |= jacobian_method( - # partial(jacobian_wrapper, output=out) - # )(input) - jacobians[out] |= jacobian_par_fn(input) - return jacobians - - return jacobian_fn - def infer_ignore( self, weak_keys: set[str], diff --git a/tests/scripts/test_compile_keys_consistencies.py b/tests/scripts/test_compile_keys_consistencies.py index 05e07639..b96110ca 100644 --- a/tests/scripts/test_compile_keys_consistencies.py +++ b/tests/scripts/test_compile_keys_consistencies.py @@ -34,7 +34,6 @@ def test_dollar_sign_str(): "constant_keys", "data_keys", "discard_keys", - "jacobian_keys", "trainable_keys", "shapes", ]: @@ -73,7 +72,6 @@ def test_connection_not_found(): "constant_keys", "data_keys", "discard_keys", - "jacobian_keys", "trainable_keys", "shapes", ]: @@ -105,7 +103,6 @@ def test_string_not_found(): "constant_keys", "data_keys", "discard_keys", - "jacobian_keys", "trainable_keys", "shapes", ]: @@ -256,22 +253,22 @@ def test_discard_keys_input_and_outputs_only(): ) -def test_jacobian_keys_inputs_only(): - """jacobian_keys can not include any keys - other than the inputs of the model. - """ - model = Model() - model += (lin_model := Linear(1, True))(input="input", output="lin_out") - model += Multiply()(output=IOKey(name="output")) - - backend = TorchBackend() - with pytest.raises(KeyError) as err_info: - ml_compile(model, backend, jacobian_keys={lin_model.output, "input"}) - assert ( - str(err_info.value) - == "'Provided jacobian keys must be subset of the input keys. " - "Invalid keys: lin_out.'" - ) +# def test_jacobian_keys_inputs_only(): +# """jacobian_keys can not include any keys +# other than the inputs of the model. +# """ +# model = Model() +# model += (lin_model := Linear(1, True))(input="input", output="lin_out") +# model += Multiply()(output=IOKey(name="output")) + +# backend = TorchBackend() +# with pytest.raises(KeyError) as err_info: +# ml_compile(model, backend, jacobian_keys={lin_model.output, "input"}) +# assert ( +# str(err_info.value) +# == "'Provided jacobian keys must be subset of the input keys. " +# "Invalid keys: lin_out.'" +# ) def test_iterable_type_keys(): diff --git a/tests/scripts/test_data_store.py b/tests/scripts/test_data_store.py index be1a53d7..51f4512a 100644 --- a/tests/scripts/test_data_store.py +++ b/tests/scripts/test_data_store.py @@ -47,7 +47,6 @@ def test_data_store_1(): data_keys=set(), constant_keys=dict(), trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=False, safe_shapes=True, @@ -77,7 +76,6 @@ def test_data_store_1_numpy(): data_keys=set(), constant_keys=dict(), trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=False, safe_shapes=True, @@ -139,7 +137,6 @@ def test_data_store_4(): data_keys=set(), constant_keys=dict(), trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=False, safe_shapes=True, @@ -426,7 +423,6 @@ def test_data_store_16(): data_keys=set(), constant_keys=dict(), trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=False, safe_shapes=True, @@ -459,7 +455,6 @@ def test_data_store_17(): data_keys=set(), constant_keys=dict(), trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=False, safe_shapes=True, @@ -490,7 +485,6 @@ def test_data_store_18(): data_keys=set(), constant_keys=dict(), trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=False, safe_shapes=True, @@ -521,7 +515,6 @@ def test_data_store_19(): data_keys=set(), constant_keys={"left": left, "right": right}, trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=False, safe_shapes=True, @@ -552,7 +545,6 @@ def test_data_store_20(): data_keys=set(), constant_keys={"left": left, "right": right}, trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=False, safe_shapes=True, diff --git a/tests/scripts/test_inference.py b/tests/scripts/test_inference.py index 332bf5fa..3ab90789 100644 --- a/tests/scripts/test_inference.py +++ b/tests/scripts/test_inference.py @@ -63,7 +63,6 @@ def test_discard_keys_inference(case: str) -> None: data_keys=set(), constant_keys=dict(), trainable_keys=set(), - jacobian_keys=set(), shapes=dict(), inference=True, safe_shapes=True, diff --git a/tests/scripts/test_pm.py b/tests/scripts/test_pm.py index 071b4263..c2df8900 100644 --- a/tests/scripts/test_pm.py +++ b/tests/scripts/test_pm.py @@ -73,7 +73,6 @@ def test_set_random_keys(): data_keys=set(), constant_keys={}, trainable_keys=set(), - jacobian_keys=set(), inference=False, safe_shapes=False, safe_names=False,