From 0436018c2e990a2c42159b48997ff51e3c63b976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Sun, 26 Jan 2025 19:10:44 +0300 Subject: [PATCH 1/2] Edge type polymorphism generalized. --- mithril/framework/codegen/numpy_gen.py | 6 +- mithril/framework/codegen/python_gen.py | 4 +- mithril/framework/common.py | 399 +++++++++++++----- mithril/framework/constraints.py | 68 +-- mithril/framework/logical/base.py | 34 +- .../framework/logical/essential_primitives.py | 191 +++++---- mithril/framework/logical/model.py | 12 +- mithril/framework/logical/primitive.py | 16 +- mithril/framework/physical/data_store.py | 8 +- mithril/framework/physical/model.py | 22 +- mithril/framework/utils.py | 152 +------ mithril/models/models.py | 278 ++++++------ mithril/models/primitives.py | 215 ++++++---- mithril/models/train_model.py | 6 +- mithril/utils/dict_conversions.py | 12 +- mithril/utils/func_utils.py | 4 +- tests/scripts/helper.py | 5 +- tests/scripts/test_constant_inputs.py | 56 ++- tests/scripts/test_constr_counter.py | 15 +- tests/scripts/test_constraints.py | 10 +- ...peredge_my_tensor.py => test_hyperedge.py} | 152 +++++-- tests/scripts/test_io_key.py | 25 +- tests/scripts/test_process_sequence.py | 6 +- .../test_randomized_models_all_backends.py | 5 +- tests/scripts/test_ref_counts.py | 27 +- tests/scripts/test_scripts.py | 57 ++- tests/scripts/test_shapes.py | 166 +++++--- tests/scripts/test_type_coercion.py | 38 +- tests/scripts/test_type_consistencies.py | 67 ++- tests/scripts/test_utils.py | 7 +- 30 files changed, 1207 insertions(+), 856 deletions(-) rename tests/scripts/{test_hyperedge_my_tensor.py => test_hyperedge.py} (70%) diff --git a/mithril/framework/codegen/numpy_gen.py b/mithril/framework/codegen/numpy_gen.py index 8a81ed7e..a95e2303 100644 --- a/mithril/framework/codegen/numpy_gen.py +++ b/mithril/framework/codegen/numpy_gen.py @@ -16,13 +16,12 @@ import keyword from collections.abc import Callable from functools import partial -from typing import Any, Literal, overload +from typing import Any, Literal, get_origin, overload import numpy as np from ...backends.with_manualgrad.numpy_backend import NumpyBackend from ...framework.physical.model import PhysicalModel -from ...framework.utils import find_intersection_type from ...utils.func_utils import is_make_array_required, prepare_function_args from ..common import ( DataEvalType, @@ -34,6 +33,7 @@ LossKey, ParamsEvalType, Tensor, + find_intersection_type, is_type_adjustment_required, ) from ..logical import PrimitiveModel @@ -318,7 +318,7 @@ def generate_evaluate_gradients( key for key in all_ignored_keys if key in self.pm.data - and self.pm.data[key].edge_type is Tensor + and get_origin(self.pm.data[key].edge_type) is Tensor and find_intersection_type(self.pm.data[key].value_type, float) } diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index f4d5d987..7dde24fe 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -19,7 +19,7 @@ from collections.abc import Callable from functools import partial from posixpath import basename, splitext -from typing import Any, Generic, Literal, Protocol, overload +from typing import Any, Generic, Literal, Protocol, get_origin, overload from ...backends.backend import ParallelBackend from ...core import DataType, Dtype @@ -301,7 +301,7 @@ def generate_imports(self) -> list[ast.stmt]: def is_static_scalar(self, key: str) -> bool: return ( key in self.pm.data_store.cached_data - and self.pm.data[key].edge_type != Tensor + and get_origin(self.pm.data[key].edge_type) != Tensor and self.pm.data[key].edge_type != Dtype and not isinstance(self.pm.data_store.cached_data[key], enum.Enum) ) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 18afca6d..b971baf1 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -27,7 +27,9 @@ Literal, Protocol, TypedDict, + TypeGuard, TypeVar, + Union, get_args, get_origin, overload, @@ -42,7 +44,6 @@ from ..utils.utils import PaddingType from .utils import ( align_shapes, - find_intersection_type, find_type, sort_type, ) @@ -221,13 +222,8 @@ class KeyType(Enum): TypeVarTensorType = TypeVar( "TypeVarTensorType", - int, - float, - bool, - int | float, - int | bool, - float | bool, - int | float | bool, + bound=int | float | bool, + covariant=True, ) # Availale types for Tensor type ("_type" attribute of Tensor class). _TensorTypes = type[int] | type[float] | type[bool] | UnionType @@ -475,7 +471,7 @@ def _delete_node(remaining: ShapeNode, deleted: ShapeNode) -> Updates: updates = remaining.merge(deleted) # Iterate over deleted nodes referees to remove deleted node. for ref in deleted.referees: - if ref.edge_type is not Tensor: + if get_origin(ref.edge_type) is not Tensor: raise ValueError("Non-tensor edges cannot have any shape.") assert isinstance(ref._value, Tensor) ref._value.shape = remaining @@ -611,7 +607,7 @@ def get_shapes( shapes: dict[str, ShapeTemplateType | list[ShapeTemplateType] | None] = {} for key, data in data_dict.items(): key_name = key_mappings.get(key, key) - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: assert data.shape is not None shapes[key_name] = data.shape.get_shapes( uniadic_keys, varadic_keys, symbolic, verbose @@ -635,8 +631,8 @@ def _find_type(value: ScalarValueType) -> ScalarType: ... def _find_type( - value: Tensor[Any] | ScalarValueType | range, -) -> type[Tensor[Any]] | ScalarType | list[int]: + value: Tensor[TypeVarTensorType] | ScalarValueType | range, +) -> type[Tensor[TypeVarTensorType]] | ScalarType | list[int]: typ: type if isinstance(value, Tensor): typ = Tensor[value.type] # type: ignore @@ -678,7 +674,7 @@ def process_value( dominant_type: type[bool] | type[int] | type[float] = bool for item in value: - # Recursively determine the shape, value and type of sublists. + # Recursively determine the shape, value and type of sub items. sub_shape, sub_val, sub_type = process_value(item) assert not isinstance(sub_val, Constant) @@ -695,6 +691,167 @@ def process_value( return [len(result)] + sub_shape, result, dominant_type +def find_intersection_type( + type_1: type | UnionType | GenericAlias | type[Tensor[int | float | bool]], + type_2: type | UnionType | GenericAlias | type[Tensor[int | float | bool]], +) -> type | UnionType | GenericAlias | type[Tensor[int | float | bool]] | None: + # ToBeDetermined type can be coerced to all types. + if type_1 is ToBeDetermined: + return type_2 + if type_2 is ToBeDetermined: + return type_1 + + # First find direct intersections. + subtypes_1 = ( + set(get_args(type_1)) if get_origin(type_1) in (UnionType, Union) else {type_1} + ) + subtypes_2 = ( + set(get_args(type_2)) if get_origin(type_2) in (UnionType, Union) else {type_2} + ) + intersect = subtypes_1 & subtypes_2 + + # Handle coercion of Any (typing.Any) type to all other types. + if Any in subtypes_1: + intersect.update(subtypes_2) + subtypes_1.remove(Any) + if Any in subtypes_2: + intersect.update(subtypes_1) + subtypes_2.remove(Any) + + # if one of the subtypes have list or tuple without an origin (without square + # brackets, ex: tuple), look for other set if it contains corresponding type + # with origin (ex: tuple[int, int]) if the set contains it, add that type with + # origin (since it contains more information) + + for s_types in (subtypes_1, subtypes_2): + other_set = subtypes_2 if s_types == subtypes_1 else subtypes_1 + for orig_type in (list, tuple, range): + if orig_type in s_types: + for typ in other_set: + if isinstance(typ, GenericAlias): + if typ.__origin__ == orig_type: + intersect.add(typ) + elif typ.__origin__ == Sequence: + if orig_type is range: + if find_intersection_type(int, typ.__args__[0]): + intersect.add(range) + else: + intersect.add( + orig_type[reduce(lambda x, y: x | y, typ.__args__)] # type: ignore + ) + + # Take tuple types from remaining sets and find intesection types + # of all consistent pairs of cartesian product. + for typ_1 in subtypes_1.difference(intersect): + # if not isinstance(typ_1, GenericAlias): + if get_origin(typ_1) is not None: + args_1 = typ_1.__args__ + for typ_2 in subtypes_2.difference(intersect): + # if not isinstance(typ_2, GenericAlias): + if get_origin(typ_2) is not None: + args_2 = typ_2.__args__ + if typ_1.__origin__ == typ_2.__origin__: + if len(args_1) == 0 or len(args_2) == 0: + # if one of the lengths of the args_1 and args_2 are zero, + # this means one of the types with origin are empty list or + # tuple, in that case, take the empty one (tuple[()], or + # list[()]) as intersection type. + common: Any = typ_1.__origin__[()] + + elif typ_1.__origin__ is tuple: + ellipsis_1 = ... in args_1 + ellipsis_2 = ... in args_2 + common = False + if ellipsis_1 and ellipsis_2: + common = find_intersection_type(args_1[0], args_2[0]) + if common: + common = [common, ...] + elif ellipsis_1: + # Remove ellipsis and replace it with base type + # as many times as length of args_2 + common = [ + find_intersection_type(args_1[0], args_2[i]) + for i in range(len(args_2)) + ] + elif ellipsis_2: + # Remove ellipsis and replace it with base type + # as many times as length of args_1 + common = [ + find_intersection_type(args_1[i], args_2[0]) + for i in range(len(args_1)) + ] + elif len(args_1) == len(args_2): + common = [ + find_intersection_type(args_1[i], args_2[i]) + for i in range(len(args_1)) + ] + if common and None not in common: + intersect.add(tuple[*common]) + + elif typ_1.__origin__ in (list, Tensor): + if len(args_2) > 1 or len(args_1) > 1: + raise TypeError( + "args of type list cannot take more than 1 element" + ) + else: + common = find_intersection_type(args_1[0], args_2[0]) + if common: + intersect.add( + list[common] + if typ_1.__origin__ is list + else Tensor[common] + ) + # TODO: Below code is duplicate of above code, refactor it. + elif typ_1.__origin__ is Sequence: + if len(args_2) > 1 or len(args_1) > 1: + raise TypeError( + "args of type Sequence cannot take " + "more than 1 element" + ) + else: + common = find_intersection_type(args_1[0], args_2[0]) + if common: + intersect.add(Sequence[common]) + + elif Sequence in (typ_1.__origin__, typ_2.__origin__): + if typ_1.__origin__ == Sequence: + coerced_type = typ_1 + other_type = typ_2 + else: + coerced_type = typ_2 + other_type = typ_1 + + other_origin = other_type.__origin__ + if other_origin is not Tensor: + # Sequence type can only be replaced with list or tuple. + assert isinstance(other_origin, type(list) | type(tuple)) + + # Replace Sequence with other origin type and resend them + # to find_intersection_type. + inner_args = reduce( + lambda x, y: x | y, coerced_type.__args__ + ) + updated_type = ( + other_origin[inner_args] + if other_type.__origin__ is list + else other_origin[inner_args, ...] + ) + common = find_intersection_type(updated_type, other_type) + if common: + intersect.add(common) + + if intersect: + result = reduce(lambda x, y: x | y, intersect) + return result + return None + + +def is_tensor_type( + typ: type | UnionType | GenericAlias | type[Tensor[int | float | bool]] | None, +) -> TypeGuard[type[Tensor[int | float | bool]]]: + return get_origin(typ) is Tensor or typ is Tensor + + class Tensor(Generic[TypeVarTensorType]): def __init__( self, @@ -715,14 +872,14 @@ def __init__( def set_type(self, typ: _TensorTypes) -> Updates: updates = Updates() - # if self.type != typ: - # new_type = find_intersection_type(typ, self.type) if self.type != (new_type := find_intersection_type(typ, self.type)): if not new_type: raise TypeError( f"Acceptable types are {sort_type(self.type)}, but " - f"{sort_type(typ)} type value is provided!" + f"{sort_type(typ)} type is provided!" ) + # TODO: Update below assertion! + assert not (is_tensor_type(new_type) or isinstance(new_type, GenericAlias)) self.type = new_type # Add all referee edges into the updates. for edge in self.referees: @@ -749,7 +906,7 @@ def set_value(self, value: TensorValueType) -> Updates: self.value = val return updates - def match(self, other: Tensor[Any]) -> Updates: + def match(self, other: Tensor[int | float | bool]) -> Updates: updates = Updates() if self is not other: updates |= self.set_type(other.type) @@ -785,13 +942,13 @@ def match_shapes(self, node: ShapeNode) -> Updates: class IOHyperEdge: - _type: type[Tensor[Any]] | ScalarType - _value: Tensor[Any] | ScalarValueType | ToBeDetermined + _type: type[Tensor[int | float | bool]] | ScalarType + _value: Tensor[int | float | bool] | ScalarValueType def __init__( self, - type: type[Tensor[Any]] | ScalarType = ToBeDetermined, - value: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, + type: type[Tensor[int | float | bool]] | ScalarType = ToBeDetermined, + value: Tensor[int | float | bool] | ScalarValueType = TBD, key_origin: str | None = None, interval: list[float | int] | None = None, ) -> None: @@ -799,6 +956,8 @@ def __init__( self.shape_constraints: set[Constraint] = set() self.type_constraints: set[Constraint] = set() self._temp_shape: ShapeRepr | None = None # set random repr + self.differentiable: bool = False + self.interval: list[float | int] | None = interval # Initially set type and value as not determined yet. self._type = ToBeDetermined self._value = TBD @@ -807,8 +966,16 @@ def __init__( # If any value is provided, set it. if value is not TBD: self.set_value(value) - self.interval: list[float | int] | None = interval - self.differentiable: bool = self.value is TBD if self._type is Tensor else False + + @property + def is_polymorphic(self) -> bool: + # Returns if the edge is of polymorphic type or not. + if self._type is ToBeDetermined: + return True + # Look for possible tensor and scalar types. + tensor_possible = find_intersection_type(Tensor[int | float | bool], self._type) + scalar_possible = find_intersection_type(ScalarValueType, self._type) + return None not in (tensor_possible, scalar_possible) @property def is_non_diff(self) -> bool: @@ -840,24 +1007,27 @@ def value_type(self) -> _TensorTypes | ScalarType: return self._type # type: ignore @property - def edge_type(self) -> type[Tensor[Any]] | ScalarType: + def edge_type(self) -> type[Tensor[int | float | bool]] | ScalarType: return self._type - def _create_and_set_tensor_value(self, typ: _TensorTypes) -> Updates: + def _create_and_set_tensor_value( + self, typ: type[Tensor[int | float | bool]] + ) -> Updates: updates = Updates() # Create a new tensor and add self to its referees # and shape referees. - tensor = Tensor(type=typ) + tensor_typ = get_args(typ)[0] + tensor: Tensor[int | float | bool] = Tensor(type=tensor_typ) tensor.referees.add(self) tensor.shape.referees.add(self) # Set type of the edge to Tensor. - self._type = Tensor + self._type = typ updates.add(self, UpdateType.TYPE) self._value = tensor return updates def _value_compatible( - self, other_value: Tensor[Any] | ScalarValueType | ToBeDetermined + self, other_value: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined ) -> bool: if self._value is not TBD: if type(self._value) is not type(other_value): @@ -868,86 +1038,80 @@ def _value_compatible( return self.value is TBD or self.value == _other_value return True - def set_type(self, typ: type[Tensor[Any]] | ScalarType) -> Updates: + def set_type(self, typ: type[Tensor[int | float | bool]] | ScalarType) -> Updates: updates = Updates() - # Tensor type setting. - if (is_generic := (get_origin(typ) is Tensor)) or typ is Tensor: - if not (self._type is Tensor or self._type is ToBeDetermined): - raise TypeError("Can not set Tensor type to a Scalar edge.") - - available_types = get_args(typ)[0] if is_generic else int | float | bool - if not isinstance(self._value, Tensor): - # This is the case when the base type is not determined yet, - # meaning it can be of any type. So, if it is requested - # to set type to Tensor, we need to create a new Tensor - # with a shape of Variadic type. - updates |= self._create_and_set_tensor_value(available_types) - else: - # Set type of Tensor object using available_types - # assert isinstance(self._value, Tensor) - updates |= self._value.set_type(available_types) - assert isinstance(self._value, Tensor) # TODO: Duplicate check. - self.differentiable = (self.value is TBD) and bool( - find_intersection_type(float, self._value.type) - ) - return updates - - elif self._type is Tensor and typ is not ToBeDetermined: - raise TypeError("Can not set Scalar type to a Tensor edge.") - - elif not (typ is ToBeDetermined or self._type == typ): - # Scalar type setting. - updates.add(self, UpdateType.TYPE) - new_type = ( - find_intersection_type(typ, self._type) - if self._type is not ToBeDetermined - else typ - ) - if not new_type: + if self._type != typ: + new_type = find_intersection_type(self._type, typ) + # If new_type is not different from the current type, return updates. + if self._type == new_type: + return updates + # None new_type means incompatible types are provided, + # raise TypeError. + if new_type is None: raise TypeError( f"Acceptable types are {sort_type(self._type)}, but " - f"{sort_type(typ)} type value is provided!" + f"{sort_type(typ)} type is provided!" ) + elif is_tensor_type(new_type): + # new_type is strictly a tensor type. + if not isinstance(self._value, Tensor): + # This is the case when the base type is not determined yet, + # meaning it can be of any type. So, if it is requested + # to set type to Tensor, we need to create a new Tensor + # with a shape of Variadic type. + updates |= self._create_and_set_tensor_value(new_type) + else: + # Set type of Tensor object using available_types + updates |= self._value.set_type(get_args(new_type)[0]) + # Add self as type update, set new type and update differentiability. + updates.add(self, UpdateType.TYPE) self._type = new_type - self.differentiable = False + self.differentiable = (self.value is TBD) and bool( + find_intersection_type(Tensor[float], self._type) + ) return updates def set_value( - self, value: Tensor[Any] | ScalarValueType | ToBeDetermined + self, value: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined ) -> Updates: updates = Updates() + tensor_possible = find_intersection_type(Tensor[int | float | bool], self._type) # If type of self and type of value is not compatible, raise an error. - if isinstance(value, Tensor) and (self._type not in (Tensor, ToBeDetermined)): + if isinstance(value, Tensor) and not ( + self._type is ToBeDetermined or tensor_possible + ): raise ValueError("Can not set Tensor value to a Scalar edge.") - if not isinstance(value, Tensor) and self._type is Tensor: + if not isinstance(value, Tensor) and get_origin(self._type) is Tensor: raise ValueError("Can not set Scalar value to a Tensor edge.") - + # If any value different than self._value is provided, raise error. if not self._value_compatible(value): raise ValueError( f"Value is set before as {self.value}. A value can not be reset." ) - if isinstance(value, Tensor) or not ( - isinstance(value, ToBeDetermined) or self.value == value - ): - # If both values are Tensor, match them. - if isinstance(self._value, Tensor) and isinstance(value, Tensor): - updates |= self._value.match(value) - elif self.edge_type is ToBeDetermined and isinstance(value, Tensor): - self._value = value - self._type = Tensor - # Add self to referees of value and shape. - self._value.referees.add(self) - self._value.shape.referees.add(self) - # Add self as a type update since type has just updated to Tensor. - updates.add(self, UpdateType.TYPE) - # TODO: When two edges set to the same tensor value using - # different Tensor objects, we need to merge their nodes into - # a single node. In order to track this, we need to add all - # uniadic symbols of all reprs to the updates. - for repr in value.shape.reprs: - for symbol in repr.prefix + repr.suffix: - updates.add(symbol) + if not (isinstance(value, ToBeDetermined) or self._value == value): + # Note that two tensor objects having same value are not equal. + # Tensor values always have to be matched with the existing one + # or set as the new value. + if isinstance(value, Tensor): + if isinstance(self._value, Tensor): + # If both values are Tensor, match them. + updates |= self._value.match(value) + elif self.is_polymorphic: + self._value = value + self._type = Tensor[self._value.type] # type: ignore + # Add self to referees of value and shape. + self._value.referees.add(self) + self._value.shape.referees.add(self) + # Add self as a type update since type has just updated to Tensor. + updates.add(self, UpdateType.TYPE) + # TODO: When two edges set to the same tensor value using + # different Tensor objects, we need to merge their nodes into + # a single node. In order to track this, we need to add all + # uniadic symbols of all reprs to the updates. + for repr in value.shape.reprs: + for symbol in repr.prefix + repr.suffix: + updates.add(symbol) else: updates |= self.set_type(_find_type(value)) self._value = value @@ -1300,31 +1464,37 @@ def __init__( @dataclass class BaseKey: - value: Tensor[Any] | ScalarValueType | TensorValueType | ToBeDetermined | str = TBD + value: ( + Tensor[int | float | bool] + | ScalarValueType + | TensorValueType + | ToBeDetermined + | str + ) = TBD shape: ShapeTemplateType | None = None - type: UnionType | type | type[Tensor[Any]] | ScalarType | None = None + type: UnionType | type | type[Tensor[int | float | bool]] | ScalarType | None = None interval: list[float | int] | None = None - # TODO: Add __post_init__ to check types and values - # TODO: Add __post_init__ to check types and values - # def __post_init__(self) -> None: - # if not isinstance(self.value, ToBeDetermined): - # value_type = _find_type(self.value) - # if self.type is not None and - # find_intersection_type(value_type, self.type) is None: - # raise TypeError( - # f"type of the given value and given type does not match. Given " - # f"type is {self.type} while type of value is {value_type}" - # ) + def __post_init__(self) -> None: + # Convert to generic Tensor type if Tensor type is provided. + if self.type is Tensor: + self.type = Tensor[int | float | bool] class IOKey(TemplateBase): def __init__( self, name: str | None = None, - value: Tensor[Any] | ScalarValueType | ToBeDetermined | str = TBD, + value: Tensor[int | float | bool] + | ScalarValueType + | ToBeDetermined + | str = TBD, shape: ShapeTemplateType | None = None, - type: UnionType | type | type[Tensor[Any]] | ScalarType | None = None, + type: UnionType + | type + | type[Tensor[int | float | bool]] + | ScalarType + | None = None, expose: bool | None = None, interval: list[float | int] | None = None, connections: set[Connection | str] | None = None, @@ -1333,9 +1503,12 @@ def __init__( # If shape is provided, type should be Tensor. if shape is not None: if type is None: - type = Tensor - elif (type is not Tensor) and (get_origin(type) is not Tensor): + type = Tensor[int | float | bool] + elif get_origin(type) is not Tensor: raise TypeError("Shape can not be provided for a non-tensor type!") + elif type is Tensor: + # Convert to generic Tensor type if Tensor type is provided. + type = Tensor[int | float | bool] self.name = name self.expose = expose @@ -1413,7 +1586,7 @@ def __eq__(self, other: object) -> bool: def set_differentiable(self, differentiable: bool = True) -> None: # TODO: Move this method to Model class as set_shapes, set_types etc. - if self.metadata.edge_type is Tensor: + if get_origin(self.metadata.edge_type) is Tensor: self.metadata.differentiable = differentiable elif differentiable: if self.metadata.edge_type is not ToBeDetermined: @@ -1429,7 +1602,7 @@ def set_differentiable(self, differentiable: bool = True) -> None: | EllipsisType | tuple[slice | int | None | EllipsisType | TemplateBase, ...] | None - | Tensor[Any] + | Tensor[int | float | bool] ) @@ -1440,7 +1613,7 @@ def set_differentiable(self, differentiable: bool = True) -> None: | NullConnection | IOKey | Connection - | Tensor[Any] + | Tensor[int | float | bool] ) ConnectionInstanceType = ( @@ -1602,7 +1775,7 @@ def get_key_origin(self, key: str) -> str | None: def get_shape_node(self, key: str) -> ShapeNode: edge = self.get_metadata(key) - if edge.edge_type is not Tensor: + if get_origin(edge.edge_type) is not Tensor: raise ValueError("'Only Tensor type connections has shape!'") assert edge.shape is not None return edge.shape @@ -2596,6 +2769,8 @@ def get_most_informative_repr(self) -> ShapeRepr: ) ): most_informative_repr = repr + if most_informative_repr is None: + ... assert most_informative_repr is not None return most_informative_repr @@ -2948,7 +3123,7 @@ def __call__(self, keys: list[IOHyperEdge]) -> ConstrainResultType: status = False updates = Updates() if self.type == UpdateType.SHAPE: - tensor_keys = [key for key in keys if key.edge_type is Tensor] + tensor_keys = [key for key in keys if get_origin(key.edge_type) is Tensor] for reprs in product(*[key.shape.reprs for key in tensor_keys]): # type: ignore for idx, repr in enumerate(reprs): tensor_keys[idx]._temp_shape = repr diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index ae43a2e1..ff2d9c03 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -49,13 +49,10 @@ UpdateType, Variadic, _TensorTypes, - process_value, -) -from .utils import ( find_intersection_type, - find_list_base_type, - is_union, + process_value, ) +from .utils import find_list_base_type, is_union __all__ = [ "edge_type_constraint", @@ -182,7 +179,7 @@ def set_edge_type(edge: IOHyperEdge, new_type: Any) -> Updates: # Simply wraps new type into Tensor if edge_type is Tensor, # else sets directly. type = new_type - if edge.edge_type is Tensor: + if get_origin(edge.edge_type) is Tensor: type = Tensor[new_type] return edge.set_type(type) @@ -192,29 +189,36 @@ def edge_type_constraint( ) -> ConstrainResultType: updates = Updates() status = False + tensor_exists: bool = False + tensor_output: bool = get_origin(output.edge_type) is Tensor + for input in inputs: + if get_origin(input.edge_type) is Tensor: + tensor_exists = True + break # First set output edge_type to Tensor if any Tensor type inputs # exists. - input_edge_types = {input.edge_type for input in inputs} - if Tensor in input_edge_types: - if output.edge_type is not Tensor: - updates |= output.set_type(Tensor) - elif output.edge_type is Tensor: + if tensor_exists: + if not tensor_output: + updates |= output.set_type(Tensor[int | float | bool]) + elif tensor_output: # Reverse edge_type inference. # If there is only one untyped input, set it as Tensor. untyped_inputs = [ input for input in inputs if input.edge_type is ToBeDetermined ] if len(untyped_inputs) == 1: - updates |= untyped_inputs.pop().set_type(Tensor) + updates |= untyped_inputs.pop().set_type(Tensor[int | float | bool]) status = True - elif output.edge_type is not ToBeDetermined: + elif output.edge_type is not ToBeDetermined and ( + find_intersection_type(Tensor[int | float | bool], output.edge_type) is None + ): # Scalar output means all inputs are scalar. for input in inputs: updates |= input.set_type(ScalarValueType) status = True - # If no ToBeDetermined edge_type exists, return True. - if ToBeDetermined not in input_edge_types: + # If no polymorphic edge_type exists, return True. + if all({not input.is_polymorphic for input in inputs}): status = True return status, updates @@ -633,18 +637,17 @@ def indexer_initial_type_constraint( updates = Updates() edge_types = {input.edge_type, output.edge_type} if edge_types != {ToBeDetermined}: - if Tensor not in edge_types: + if Tensor not in {get_origin(typ) for typ in edge_types}: # Meaning that indexing scalar type data. Set general # scalar type constraints on all arguments. - # TODO: Types should be more specific. updates |= output.set_type(int | float | list[Any] | tuple[Any, ...]) updates |= input.set_type(list[Any] | tuple[Any, ...]) status = True else: - tensor_edge = input if input.edge_type is Tensor else output + tensor_edge = input if get_origin(input.edge_type) is Tensor else output assert isinstance(tensor_edge._value, Tensor) - typ: type[Tensor[Any]] = Tensor[tensor_edge.value_type] # type: ignore + typ: type[Tensor[int | float | bool]] = Tensor[tensor_edge.value_type] # type: ignore other_edge = (input, output)[tensor_edge is input] updates |= other_edge.set_type(typ) status = True @@ -656,7 +659,7 @@ def indexer_type_constraint( ) -> ConstrainResultType: status = False updates = Updates() - if input.edge_type not in (Tensor, ToBeDetermined): + if not (get_origin(input.edge_type) is Tensor or input.edge_type is ToBeDetermined): # Input is a non-tensor type. input_type = input.value_type output_type = output.value_type @@ -703,7 +706,7 @@ def indexer_type_constraint( updates |= output.set_type(inferred_out_type) status = not is_union(output.value_type) - elif input.edge_type is Tensor: + elif get_origin(input.edge_type) is Tensor: status = True return status, updates @@ -1412,8 +1415,9 @@ def bcast_helper( def _bcast( output: IOHyperEdge, left: IOHyperEdge, right: IOHyperEdge, index: int ) -> ConstrainResultType: - l_type = left.edge_type - r_type = right.edge_type + l_type = Tensor if get_origin(left.edge_type) is Tensor else left.edge_type + r_type = Tensor if get_origin(right.edge_type) is Tensor else right.edge_type + o_type = Tensor if get_origin(output.edge_type) is Tensor else output.edge_type if l_type is Tensor and r_type is Tensor: assert output._temp_shape is not None, "Output shape of broadcast is not set!" assert left._temp_shape is not None, "Left shape of broadcast is not set!" @@ -1421,15 +1425,15 @@ def _bcast( return bcast_helper( output._temp_shape, left._temp_shape, right._temp_shape, index ) - elif not ({Tensor, ToBeDetermined} & {l_type, r_type, output.edge_type}): + elif not ({Tensor, ToBeDetermined} & {l_type, r_type, o_type}): # Means all edges are scalar types. Simply return True # without any updates. return True, Updates() merge_edge: IOHyperEdge | None = None - if left.edge_type is Tensor: + if l_type is Tensor: merge_edge = left - elif right.edge_type is Tensor: + elif r_type is Tensor: merge_edge = right if merge_edge is not None: @@ -3683,7 +3687,7 @@ def tensor_item_constraint_helper( def indexer_constraints( output: IOHyperEdge, input: IOHyperEdge, index: IOHyperEdge ) -> ConstrainResultType: - if input.edge_type is Tensor: + if get_origin(input.edge_type) is Tensor: return tensor_item_constraints(output, input, index) elif input.edge_type is not ToBeDetermined: return scalar_item_constraints(output, input, index) @@ -3942,7 +3946,7 @@ def buffer_constraint(output: IOHyperEdge, input: IOHyperEdge) -> ConstrainResul typed_edge, other_edge = output, input if typed_edge is not None: - if typed_edge.edge_type is Tensor: + if typed_edge._value is not TBD: updates |= other_edge.set_value(typed_edge._value) else: updates |= other_edge.set_type(typed_edge.edge_type) @@ -3956,7 +3960,7 @@ def relational_operator_type_constraint( updates = Updates() status = False # Forward inference. - if Tensor in (input1.edge_type, input2.edge_type): + if Tensor in (get_origin(input1.edge_type), get_origin(input2.edge_type)): updates |= output.set_type(Tensor[bool]) status = True elif ToBeDetermined not in (input1.edge_type, input2.edge_type): @@ -3971,7 +3975,7 @@ def divide_type_constraint( updates = Updates() status = False # Forward inference. - if Tensor in (numerator.edge_type, denominator.edge_type): + if Tensor in (get_origin(numerator.edge_type), get_origin(denominator.edge_type)): updates |= output.set_type(Tensor[float]) status = True elif ToBeDetermined not in (numerator.edge_type, denominator.edge_type): @@ -3989,13 +3993,13 @@ def polynomial_kernel_constraint( # poly_coef update. if poly_coef.edge_type is not ToBeDetermined: coef_status = True - if poly_coef.edge_type is Tensor: + if get_origin(poly_coef.edge_type) is Tensor: assert poly_coef.shape is not None updates |= poly_coef.shape.set_values([]) # degree update. if degree.edge_type is not ToBeDetermined: degree_status = True - if degree.edge_type is Tensor: + if get_origin(degree.edge_type) is Tensor: assert degree.shape is not None updates |= degree.shape.set_values([]) return coef_status & degree_status, updates diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index 65c49941..4c3d73c4 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -97,7 +97,7 @@ def __init__(self, name: str | None = None, enforce_jit: bool = True) -> None: self.assigned_shapes: list[ShapesType] = [] self.assigned_types: dict[ str, - type | UnionType | ScalarType | Tensor[Any], + type | UnionType | ScalarType | Tensor[int | float | bool], ] = {} self.assigned_constraints: list[AssignedConstraintType] = [] self.conns = Connections() @@ -257,12 +257,10 @@ def _set_shapes( # Apply updates to the shape nodes. for key in chain(shapes, kwargs): node, _inner_key = shape_nodes[key] - if ( - metadata := self.conns.get_data(_inner_key) - ).edge_type is ToBeDetermined: + if (metadata := self.conns.get_data(_inner_key)).is_polymorphic: # If edge_type is not defined yet, set it to Tensor since # shape is provided. - updates |= metadata.set_type(Tensor) + updates |= metadata.set_type(Tensor[int | float | bool]) shape_node = self.conns.get_shape_node(_inner_key) assert shape_node is not None updates |= shape_node.merge(node) @@ -273,7 +271,9 @@ def _set_shapes( model.constraint_solver(updates) def _set_value( - self, key: ConnectionData, value: MainValueType | Tensor[Any] | str + self, + key: ConnectionData, + value: MainValueType | Tensor[int | float | bool] | str, ) -> Updates: """ Set value for the given connection. @@ -302,11 +302,13 @@ def set_shapes( def set_values( self, - config: Mapping[str | Connection, Tensor[Any] | MainValueType | str] - | Mapping[Connection, Tensor[Any] | MainValueType | str] - | Mapping[str, Tensor[Any] | MainValueType | str] + config: Mapping[ + str | Connection, Tensor[int | float | bool] | MainValueType | str + ] + | Mapping[Connection, Tensor[int | float | bool] | MainValueType | str] + | Mapping[str, Tensor[int | float | bool] | MainValueType | str] | None = None, - **kwargs: Tensor[Any] | MainValueType | str, + **kwargs: Tensor[int | float | bool] | MainValueType | str, ) -> None: """ Set multiple values in the model. @@ -346,18 +348,18 @@ def set_types( self, config: Mapping[ str | Connection, - type | UnionType | ScalarType | type[Tensor[Any]], + type | UnionType | ScalarType | type[Tensor[int | float | bool]], ] | Mapping[ Connection, - type | UnionType | ScalarType | type[Tensor[Any]], + type | UnionType | ScalarType | type[Tensor[int | float | bool]], ] | Mapping[ str, - type | UnionType | ScalarType | type[Tensor[Any]], + type | UnionType | ScalarType | type[Tensor[int | float | bool]], ] | None = None, - **kwargs: type | UnionType | ScalarType | type[Tensor[Any]], + **kwargs: type | UnionType | ScalarType | type[Tensor[int | float | bool]], ) -> None: """ Set types of any connection in the Model @@ -379,7 +381,7 @@ def set_types( # Initialize assigned shapes dictionary to store assigned shapes. assigned_types: dict[ str, - type | UnionType | ScalarType | Tensor[Any], + type | UnionType | ScalarType | Tensor[int | float | bool], ] = {} # Get the outermost parent as all the updates will happen here. @@ -390,6 +392,8 @@ def set_types( conn = self.conns.get_con_by_metadata(metadata) assert conn is not None inner_key = conn.key + if key_type is Tensor: + key_type = Tensor[int | float | bool] assigned_types[inner_key] = key_type updates |= metadata.set_type(key_type) # Store assigned types in the model. diff --git a/mithril/framework/logical/essential_primitives.py b/mithril/framework/logical/essential_primitives.py index facde87e..e34650e3 100644 --- a/mithril/framework/logical/essential_primitives.py +++ b/mithril/framework/logical/essential_primitives.py @@ -125,7 +125,7 @@ class Buffer(PrimitiveModel): def __init__( self, - input: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -154,7 +154,7 @@ def __init__( n: int, *, name: str | None = None, - **kwargs: Tensor[Any] | ScalarValueType | ToBeDetermined, + **kwargs: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined, ) -> None: self.factory_args = {"n": n} key_definitions = { @@ -187,8 +187,8 @@ class ArithmeticOperation(PrimitiveModel): def __init__( self, formula_key: str, - left: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, - right: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -224,8 +224,12 @@ class Power(PrimitiveModel): def __init__( self, robust: bool = False, - base: Tensor[Any] | int | float | ToBeDetermined = TBD, - exponent: Tensor[Any] | int | float | ToBeDetermined = TBD, + base: Tensor[int | float | bool] | int | float | bool | ToBeDetermined = TBD, + exponent: Tensor[int | float | bool] + | int + | float + | bool + | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -237,7 +241,7 @@ def __init__( super().__init__( formula_key="robust_power", name=name, - output=BaseKey(shape=[("out", ...)], type=Tensor), + output=BaseKey(shape=[("out", ...)], type=Tensor[int | float]), base=BaseKey(shape=[("base", ...)], type=Tensor, value=base), exponent=BaseKey(shape=[("exp", ...)], type=Tensor, value=exponent), threshold=BaseKey(shape=[], type=Tensor), @@ -252,9 +256,13 @@ def __init__( super().__init__( formula_key="power", name=name, - output=BaseKey(), - base=BaseKey(value=base), - exponent=BaseKey(value=exponent), + output=BaseKey(type=Tensor[int | float] | int | float), + base=BaseKey( + type=Tensor[int | float | bool] | int | float | bool, value=base + ), + exponent=BaseKey( + type=Tensor[int | float | bool] | int | float | bool, value=exponent + ), ) self._set_constraint( fn=edge_type_constraint, @@ -291,8 +299,8 @@ def __call__( # type: ignore[override] class Add(ArithmeticOperation): def __init__( self, - left: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, - right: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -302,8 +310,8 @@ def __init__( class Subtract(ArithmeticOperation): def __init__( self, - left: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, - right: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -313,8 +321,8 @@ def __init__( class Multiply(ArithmeticOperation): def __init__( self, - left: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, - right: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -350,8 +358,10 @@ class Divide(PrimitiveModel): def __init__( self, - numerator: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, - denominator: Tensor[Any] | ScalarValueType | ToBeDetermined = TBD, + numerator: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + denominator: Tensor[int | float | bool] + | ScalarValueType + | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -387,15 +397,15 @@ class FloorDivide(PrimitiveModel): # TODO: Torch does not accept bool type inputs while JAX and other accepts! def __init__( self, - numerator: Tensor[Any] | ToBeDetermined = TBD, - denominator: Tensor[Any] | ToBeDetermined = TBD, + numerator: Tensor[int | float | bool] | ToBeDetermined = TBD, + denominator: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: super().__init__( formula_key="floor_divide", name=name, - output=BaseKey(shape=[("Var_out", ...)], type=Tensor), + output=BaseKey(shape=[("Var_out", ...)], type=Tensor[int | float]), numerator=BaseKey(shape=[("Var_1", ...)], type=Tensor, value=numerator), denominator=BaseKey(shape=[("Var_2", ...)], type=Tensor, value=denominator), ) @@ -426,8 +436,8 @@ class MatrixMultiply(PrimitiveModel): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -460,7 +470,10 @@ class Shape(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="shape", @@ -484,7 +497,7 @@ class Reshape(PrimitiveModel): def __init__( self, shape: tuple[int | None, ...] | list[int] | ToBeDetermined = TBD, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -517,7 +530,10 @@ class Length(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="length", @@ -562,7 +578,10 @@ class Dtype(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="dtype", @@ -585,7 +604,7 @@ class Size(PrimitiveModel): def __init__( self, dim: int | tuple[int, ...] | None | ToBeDetermined = None, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -613,7 +632,10 @@ class Item(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="item", @@ -702,7 +724,10 @@ class TensorToList(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="tensor_to_list", @@ -736,7 +761,7 @@ def __init__( formula_key: str, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool | ToBeDetermined = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, **kwargs: BaseKey, @@ -783,7 +808,7 @@ def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool | ToBeDetermined = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -802,7 +827,7 @@ def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -819,7 +844,7 @@ def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -836,7 +861,7 @@ def __init__( self, axis: int | None | ToBeDetermined = None, keepdim: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -856,7 +881,7 @@ def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -873,7 +898,7 @@ def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -893,7 +918,7 @@ def __init__( self, axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -917,7 +942,7 @@ def __init__( axis: int | tuple[int, ...] | None | ToBeDetermined = None, keepdim: bool = False, correction: int | float | None = 0.0, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -958,7 +983,7 @@ def __init__( self, formula_key: str, polymorphic_constraint: bool = True, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, **kwargs: BaseKey, @@ -985,21 +1010,30 @@ def __call__( # type: ignore[override] class Absolute(SingleInputOperation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(formula_key="abs", name=name, input=input) class Minus(SingleInputOperation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(formula_key="minus", name=name, input=input) class Exponential(SingleInputOperation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="exp", @@ -1017,9 +1051,9 @@ class Sqrt(PrimitiveModel): def __init__( self, robust: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, - cutoff: Tensor[Any] | ToBeDetermined = TBD, + cutoff: Tensor[int | float | bool] | ToBeDetermined = TBD, name: str | None = None, ) -> None: self.robust = robust @@ -1074,8 +1108,8 @@ class RelationalOperators(PrimitiveModel): def __init__( self, formula_key: str, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1105,8 +1139,8 @@ def __call__( # type: ignore[override] class Greater(RelationalOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1116,8 +1150,8 @@ def __init__( class Less(RelationalOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1127,8 +1161,8 @@ def __init__( class Equal(RelationalOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1138,8 +1172,8 @@ def __init__( class NotEqual(RelationalOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1149,8 +1183,8 @@ def __init__( class LessEqual(RelationalOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1160,8 +1194,8 @@ def __init__( class GreaterEqual(RelationalOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ScalarValueType | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1173,7 +1207,10 @@ class LogicalNot(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="logical_not", @@ -1196,8 +1233,8 @@ class BitwiseOperators(PrimitiveModel): def __init__( self, formula_key: str, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1222,8 +1259,8 @@ def __call__( # type: ignore[override] class LogicalAnd(BitwiseOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1233,8 +1270,8 @@ def __init__( class LogicalOr(BitwiseOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1244,8 +1281,8 @@ def __init__( class LogicalXOr(BitwiseOperators): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1260,8 +1297,8 @@ class ShiftLeft(PrimitiveModel): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - shift: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + shift: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1291,8 +1328,8 @@ class ShiftRight(PrimitiveModel): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - shift: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + shift: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1325,7 +1362,7 @@ class Transpose(PrimitiveModel): def __init__( self, axes: int | list[int] | tuple[int, ...] | None | ToBeDetermined = None, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1390,7 +1427,7 @@ def __init__( self, split_size: int, # TODO: should we add default for split_size? axis: int = 0, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ): @@ -1463,7 +1500,7 @@ class Indexer(PrimitiveModel): def __init__( self, index: int | ToBeDetermined = TBD, - input: Tensor[Any] | Sequence[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | Sequence[Any] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1504,7 +1541,7 @@ def __call__( # type: ignore[override] class Sine(SingleInputOperation): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1520,7 +1557,7 @@ def __init__( class Cosine(SingleInputOperation): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index c65945e1..60372ea1 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -16,7 +16,7 @@ from collections.abc import Mapping from types import UnionType -from typing import Any, Self +from typing import Any, Self, get_origin from ...utils.utils import OrderedSet, find_dominant_type from ..common import ( @@ -376,7 +376,11 @@ def _add_connection( outer_key = given_connection.name con_obj = None set_value: ( - ToBeDetermined | str | ScalarValueType | Tensor[Any] | NullConnection + ToBeDetermined + | str + | ScalarValueType + | Tensor[int | float | bool] + | NullConnection ) = NOT_GIVEN if given_connection.data.value is not TBD: set_value = given_connection.data.value @@ -687,7 +691,7 @@ def extend( shape_info: dict[str, ShapeTemplateType] = {} type_info: dict[ str, - type | UnionType | ScalarType | type[Tensor[Any]], + type | UnionType | ScalarType | type[Tensor[int | float | bool]], ] = {} submodel_dag: dict[str, ConnectionData] = {} @@ -714,7 +718,7 @@ def extend( con_obj, _updates = self._add_connection(model, local_key, value, updates) updates |= _updates submodel_dag[local_key] = con_obj - if con_obj.metadata.edge_type is Tensor: + if get_origin(con_obj.metadata.edge_type) is Tensor: updates.shape_updates.add(con_obj.metadata) # Replace shape info keys, which are local keys, with global equivalents. diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index df17b4a3..687e324d 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -69,18 +69,11 @@ def __init__( output_data: IOHyperEdge | None = None for key, value in kwargs.items(): if isinstance(value, BaseKey): - if ( - is_generic_tensor := (get_origin(value.type) is Tensor) - ) or value.type is Tensor: - tensor_types = ( - get_args(value.type)[0] - if is_generic_tensor - else int | float | bool - ) + if get_origin(value.type) is Tensor: if not isinstance(tensor := value.value, Tensor): assert isinstance(value.value, ToBeDetermined) tensor = Tensor( - type=tensor_types, + type=get_args(value.type)[0], shape=shapes[key].node, ) edge = IOHyperEdge(value=tensor, interval=value.interval) @@ -186,7 +179,10 @@ def extract_connection_info( # try to find outer key's real name in data_to_key_map outer_key = data_to_key_map.get(key_data, [key]) outer_key = ["'" + key + "'" for key in outer_key] - if key_data.edge_type is not Tensor and key_data.value is not TBD: + if ( + get_origin(key_data.edge_type) is not Tensor + and key_data.value is not TBD + ): # If value of the scalar is determined, write that value directly. outer_key = [str(key_data.value)] conn.extend(outer_key) diff --git a/mithril/framework/physical/data_store.py b/mithril/framework/physical/data_store.py index 7edfef1a..8275670f 100644 --- a/mithril/framework/physical/data_store.py +++ b/mithril/framework/physical/data_store.py @@ -14,7 +14,7 @@ from collections.abc import Mapping, Sequence from copy import deepcopy -from typing import Any, Generic, TypeGuard +from typing import Any, Generic, TypeGuard, get_origin from ...backends.backend import Backend from ...core import Constant, DataType, Dtype, data_types, epsilon_table @@ -164,7 +164,7 @@ def _set_data_value(self, key: str, data: IOHyperEdge) -> None: if isinstance(value, Constant): value = epsilon_table[self.backend.precision][value] - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: value = self.backend.array(value) elif isinstance(value, Dtype): value = getattr(self.backend, value.name) @@ -224,7 +224,7 @@ def set_shapes( if isinstance(key, Connection): key = key.key assert isinstance(key, str) - if (data := self._all_data[key]).edge_type is not Tensor: + if get_origin((data := self._all_data[key]).edge_type) is not Tensor: raise ValueError("Non-tensor data can not have shape!") assert data.shape is not None updates |= data.shape.set_values(value) @@ -272,7 +272,7 @@ def set_static_keys( raise KeyError( "Requires static key to be in the input keys of the model!" ) - if (self._all_data[key].edge_type is Tensor) and not isinstance( + if (get_origin(self._all_data[key].edge_type) is Tensor) and not isinstance( value, ToBeDetermined | self.backend.get_backend_array_type() ): raise ValueError( diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 937b8d9b..38828602 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -20,6 +20,7 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass +from typing import get_origin from ...backends.backend import Backend, ParallelBackend from ...core import DataType, GenericDataType @@ -46,6 +47,7 @@ Updates, Variadic, create_shape_map, + find_intersection_type, get_shapes, get_summary, get_summary_shapes, @@ -54,7 +56,7 @@ from ..logical.base import BaseModel from ..logical.model import Model from ..logical.primitive import PrimitiveModel -from ..utils import define_unique_names, find_intersection_type +from ..utils import define_unique_names from .data_store import StaticDataStore from .flat_graph import FlatGraph @@ -200,13 +202,17 @@ def __init__( # TODO: Create an API for setting differentiability of a tensor. physical_data.differentiable = False elif global_key in self._trainable_tensor_inputs: - if physical_data.edge_type not in (Tensor, ToBeDetermined): + # if physical_data.edge_type not in (Tensor, ToBeDetermined): + if not ( + get_origin(physical_data.edge_type) is Tensor + or physical_data.edge_type is ToBeDetermined + ): raise ValueError( f"Non-tensor type data can not be trainable: {global_key}" ) elif physical_data.edge_type is ToBeDetermined: # Set physical data type to Tensor. - updates |= physical_data.set_type(Tensor) + updates |= physical_data.set_type(Tensor[float]) elif physical_data.value is not TBD: raise ValueError( f"Valued data can not be trainable: {global_key}" @@ -218,7 +224,7 @@ def __init__( if key_shape := model_shapes.get(key): data = model_data[key] - assert data.edge_type is Tensor + assert get_origin(data.edge_type) is Tensor shp = data.shape assert shp is not None # assert shp is not None @@ -429,7 +435,7 @@ def _infer_differentiability(self, model_data: dict[str, IOHyperEdge]) -> None: # that have a Tensor type output. output_key = PrimitiveModel.output_key output_edge = model_data[output_key] - if output_edge.edge_type is Tensor: + if get_origin(output_edge.edge_type) is Tensor: # If any of the inputs are differentiable, then # the output is also differentiable. for key, value in model_data.items(): @@ -546,7 +552,7 @@ def _pre_compile( for value in self.data_store.intermediate_non_differentiables.inverse: # there can exist some inferred intermediate scalar keys in logical model. # find those keys and add to cached datas - if (value.edge_type is not Tensor) and (value.value is not TBD): + if (get_origin(value.edge_type) is not Tensor) and (value.value is not TBD): updates.add(value) self.data_store.update_cached_data(updates) @@ -601,7 +607,7 @@ def _pre_compile( # but not unnecessary in flat_graph. This case should be handled when # flat_graph - data_store integration is updated. if conn_edge is not None and ( - (conn_edge.edge_type is not Tensor) + (get_origin(conn_edge.edge_type) is not Tensor) or ( (not find_intersection_type(float, conn_edge.value_type)) or _key @@ -954,7 +960,7 @@ def extract_connection_info( # model. Indicate it accordingly input_name = "'" + connection.key + "'" input_data = model.conns.all[input_key].metadata - if input_data.edge_type is not Tensor: + if get_origin(input_data.edge_type) is not Tensor: # If value of the scalar is determined, write that value pm_input_data = self.data_store.data_memo[id(input_data)] if (val := pm_input_data.value) is not TBD: diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 402c096b..82b053d7 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable from functools import reduce from itertools import product from types import FunctionType, GenericAlias, UnionType @@ -230,156 +230,6 @@ def find_list_depth(arg_type: type | UnionType | GenericAlias) -> int: return max_depth -def find_intersection_type( - type_1: type | UnionType | GenericAlias, - type_2: type | UnionType | GenericAlias, -) -> type | UnionType | None: - # First find direct intersections. - subtypes_1 = set(type_1.__args__) if type(type_1) is UnionType else {type_1} - subtypes_2 = set(type_2.__args__) if type(type_2) is UnionType else {type_2} - intersect = subtypes_1 & subtypes_2 - - # Handle coercion of Any (typing.Any) type to all other types. - if Any in subtypes_1: - intersect.update(subtypes_2) - subtypes_1.remove(Any) - if Any in subtypes_2: - intersect.update(subtypes_1) - subtypes_2.remove(Any) - - # if one of the subtypes have list or tuple without an origin (without square - # brackets, ex: tuple), look for other set if it contains corresponding type - # with origin (ex: tuple[int, int]) if the set contains it, add that type with - # origin (since it contains more information) - - for s_types in (subtypes_1, subtypes_2): - other_set = subtypes_2 if s_types == subtypes_1 else subtypes_1 - for orig_type in (list, tuple, range): - if orig_type in s_types: - for typ in other_set: - if isinstance(typ, GenericAlias): - if typ.__origin__ == orig_type: - intersect.add(typ) - elif typ.__origin__ == Sequence: - if orig_type is range: - if find_intersection_type(int, typ.__args__[0]): - intersect.add(range) - else: - intersect.add( - orig_type[reduce(lambda x, y: x | y, typ.__args__)] # type: ignore - ) - - # Take tuple types from remaining sets and find intesection types - # of all consistent pairs of cartesian product. - for typ_1 in subtypes_1.difference(intersect): - if not isinstance(typ_1, GenericAlias): - continue - - args_1 = typ_1.__args__ - assert ( - typ_1.__origin__ is tuple - or typ_1.__origin__ is list - or typ_1.__origin__ is Sequence - or typ_1.__origin__ is dict - ) - for typ_2 in subtypes_2.difference(intersect): - if not isinstance(typ_2, GenericAlias): - continue - - args_2 = typ_2.__args__ - assert ( - typ_2.__origin__ is tuple - or typ_2.__origin__ is list - or typ_2.__origin__ is dict - or typ_2.__origin__ is Sequence - ) - if typ_1.__origin__ == typ_2.__origin__: - if len(args_1) == 0 or len(args_2) == 0: - # if one of the lengths of the args_1 and args_2 are zero, - # this means one of the types with origin are empty list or tuple, - # in that case, take the empty one (tuple[()], or list[()]) as - # intersection type - common: Any = typ_1.__origin__[()] # type: ignore - - elif typ_1.__origin__ is tuple: - ellipsis_1 = ... in args_1 - ellipsis_2 = ... in args_2 - common = False - if ellipsis_1 and ellipsis_2: - common = find_intersection_type(args_1[0], args_2[0]) - if common: - common = [common, ...] - elif ellipsis_1: - # Remove ellipsis and replace it with base type - # as many times as length of args_2 - common = [ - find_intersection_type(args_1[0], args_2[i]) - for i in range(len(args_2)) - ] - elif ellipsis_2: - # Remove ellipsis and replace it with base type - # as many times as length of args_1 - common = [ - find_intersection_type(args_1[i], args_2[0]) - for i in range(len(args_1)) - ] - elif len(args_1) == len(args_2): - common = [ - find_intersection_type(args_1[i], args_2[i]) - for i in range(len(args_1)) - ] - if common and None not in common: - intersect.add(tuple[*common]) - - elif typ_1.__origin__ is list: - if len(args_2) > 1 or len(args_1) > 1: - raise TypeError( - "args of type list cannot take more than 1 element" - ) - else: - common = find_intersection_type(args_1[0], args_2[0]) - if common: - intersect.add(list[common]) - # TODO: Below code is duplicate of above code, refactor it. - elif typ_1.__origin__ is Sequence: - if len(args_2) > 1 or len(args_1) > 1: - raise TypeError( - "args of type Sequence cannot take more than 1 element" - ) - else: - common = find_intersection_type(args_1[0], args_2[0]) - if common: - intersect.add(Sequence[common]) - - elif Sequence in (typ_1.__origin__, typ_2.__origin__): - if typ_1.__origin__ == Sequence: - coerced_type = typ_1 - other_type = typ_2 - else: - coerced_type = typ_2 - other_type = typ_1 - - other_origin = other_type.__origin__ - assert isinstance(other_origin, type(list) | type(tuple)) - - # Replace Sequence with other origin type and resend them - # to find_intersection_type. - inner_args = reduce(lambda x, y: x | y, coerced_type.__args__) - updated_type = ( - other_origin[inner_args] - if other_type.__origin__ is list - else other_origin[inner_args, ...] - ) - common = find_intersection_type(updated_type, other_type) - if common: - intersect.add(common) - - if intersect: - result = reduce(lambda x, y: x | y, intersect) - return result - return None - - def find_type[T](connection: T) -> type[T]: if isinstance(connection, tuple | list): element_types: list[Any] = [find_type(elem) for elem in connection] diff --git a/mithril/models/models.py b/mithril/models/models.py index 2b61c432..1ba2a458 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -18,7 +18,6 @@ from abc import abstractmethod from collections.abc import Sequence from copy import deepcopy -from typing import Any from ..framework import Model from ..framework.common import ( @@ -31,6 +30,8 @@ ShapeTemplateType, Tensor, ToBeDetermined, + TypeVarTensorType, + ValType, ) from ..framework.constraints import polynomial_kernel_constraint from ..framework.logical.base import BaseModel, ExtendInfo @@ -165,7 +166,7 @@ def __init__( stride: int | None | ToBeDetermined = None, padding: int | PaddingType | tuple[int, int] | ToBeDetermined = (0, 0), dilation: int | ToBeDetermined = 1, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -257,7 +258,7 @@ def __init__( stride: int | None | tuple[int, int] | ToBeDetermined = None, padding: int | PaddingType | tuple[int, int] | ToBeDetermined = (0, 0), dilation: int | ToBeDetermined = 1, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -357,8 +358,8 @@ def __init__( padding: int | PaddingType | tuple[int, int] | ToBeDetermined = 0, dilation: int | ToBeDetermined = 1, use_bias: bool = True, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -439,8 +440,8 @@ def __init__( | ToBeDetermined = (0, 0), dilation: int | tuple[int, int] | ToBeDetermined = (1, 1), use_bias: bool = True, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -524,9 +525,9 @@ def __init__( self, dimension: int | None = None, use_bias: bool = True, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -545,7 +546,7 @@ def __init__( input_key = IOKey(name="input", value=input) weight_key = IOKey(name="weight", value=weight).transpose() if use_bias: - bias_key = IOKey(name="bias", value=bias, type=Tensor) + bias_key = IOKey(name="bias", value=bias, type=Tensor[ValType]) self |= mult(left=input_key, right=weight_key) self |= Add()(left=mult.output, right=bias_key, output=output) shapes["bias"] = [dim] @@ -583,9 +584,9 @@ class ElementWiseAffine(Model): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -631,9 +632,9 @@ def __init__( self, activation: BaseModel, dimension: int | None = None, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -675,9 +676,9 @@ def __init__( use_scale: bool = True, use_bias: bool = True, eps: float = 1e-5, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -687,10 +688,10 @@ def __init__( # Expects its input shape as [B, ..., d] d refers to normalized dimension mean = Mean(axis=-1, keepdim=True) numerator = Subtract() - numerator.set_types(left=Tensor, right=Tensor) + numerator.set_types(left=Tensor[ValType], right=Tensor[ValType]) var = Variance(axis=-1, correction=0, keepdim=True) add = Add() - add.set_types(left=Tensor) + add.set_types(left=Tensor[ValType]) denominator = Sqrt() in_key = IOKey("input", value=input) self += mean(input=in_key) @@ -710,13 +711,13 @@ def __init__( if use_scale: mult = Multiply() - mult.set_types(left=Tensor, right=Tensor) + mult.set_types(left=Tensor[ValType], right=Tensor[ValType]) self += mult(left=self.cout, right=IOKey("weight", value=weight)) mult._set_shapes(shapes) if use_bias: add = Add() - add.set_types(left=Tensor, right=Tensor) + add.set_types(left=Tensor[ValType], right=Tensor[ValType]) self += add(left=self.cout, right=IOKey("bias", value=bias)) add._set_shapes(shapes) # TODO: Remove below Buffer after required naming-related changes are done. @@ -757,10 +758,10 @@ def __init__( use_scale: bool = True, use_bias: bool = True, eps: float = 1e-5, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, name: str | None = None, ) -> None: super().__init__(name=name) @@ -830,7 +831,10 @@ class L1(Model): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(name=name) @@ -857,7 +861,10 @@ class L2(Model): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(name=name) square = Square() @@ -888,8 +895,8 @@ class QuadraticFormRegularizer(Model): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - kernel: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + kernel: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -933,10 +940,10 @@ class RBFKernel(Model): def __init__( self, - input1: Tensor[Any] | ToBeDetermined = TBD, - input2: Tensor[Any] | ToBeDetermined = TBD, - l_scale: Tensor[Any] | ToBeDetermined = TBD, - sigma: Tensor[Any] | ToBeDetermined = TBD, + input1: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input2: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + l_scale: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + sigma: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1011,10 +1018,10 @@ class PolynomialKernel(Model): def __init__( self, robust: bool = True, - input1: Tensor[Any] | ToBeDetermined = TBD, - input2: Tensor[Any] | ToBeDetermined = TBD, - poly_coef: Tensor[Any] | ToBeDetermined = TBD, - degree: Tensor[Any] | ToBeDetermined = TBD, + input1: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input2: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + poly_coef: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + degree: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1076,11 +1083,11 @@ class KernelizedSVM(Model): def __init__( self, kernel: BaseModel, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, - **kwargs: Tensor[Any] | ToBeDetermined, + **kwargs: Tensor[TypeVarTensorType] | ToBeDetermined, ) -> None: if len(kernel.input_keys) < 2: raise KeyError("Kernel requires at least two inputs!") @@ -1147,9 +1154,9 @@ class LinearSVM(Model): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1198,9 +1205,9 @@ class LogisticRegression(Model): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1249,10 +1256,10 @@ def __init__( activations: list[BaseModel], dimensions: Sequence[int | None], input_name_templates: dict[str, str] | None = None, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, - **weights_biases: Tensor[Any] | ToBeDetermined, + **weights_biases: Tensor[TypeVarTensorType] | ToBeDetermined, ) -> None: super().__init__(name=name) self.factory_args = {"activations": activations, "dimensions": dimensions} @@ -1357,12 +1364,12 @@ class RNNCell(Cell): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - w_ih: Tensor[Any] | ToBeDetermined = TBD, - w_hh: Tensor[Any] | ToBeDetermined = TBD, - w_ho: Tensor[Any] | ToBeDetermined = TBD, - bias_h: Tensor[Any] | ToBeDetermined = TBD, - bias_o: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_ih: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_hh: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_ho: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_h: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1484,17 +1491,17 @@ class LSTMCell(Cell): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - w_i: Tensor[Any] | ToBeDetermined = TBD, - w_f: Tensor[Any] | ToBeDetermined = TBD, - w_c: Tensor[Any] | ToBeDetermined = TBD, - w_o: Tensor[Any] | ToBeDetermined = TBD, - w_out: Tensor[Any] | ToBeDetermined = TBD, - bias_f: Tensor[Any] | ToBeDetermined = TBD, - bias_i: Tensor[Any] | ToBeDetermined = TBD, - bias_c: Tensor[Any] | ToBeDetermined = TBD, - bias_o: Tensor[Any] | ToBeDetermined = TBD, - bias_out: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_i: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_f: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_c: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_out: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_f: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_i: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_c: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_out: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1655,17 +1662,17 @@ class LSTMCellBody(Model): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - prev_hidden: Tensor[Any] | ToBeDetermined = TBD, - prev_cell: Tensor[Any] | ToBeDetermined = TBD, - w_i: Tensor[Any] | ToBeDetermined = TBD, - w_f: Tensor[Any] | ToBeDetermined = TBD, - w_c: Tensor[Any] | ToBeDetermined = TBD, - w_o: Tensor[Any] | ToBeDetermined = TBD, - bias_f: Tensor[Any] | ToBeDetermined = TBD, - bias_i: Tensor[Any] | ToBeDetermined = TBD, - bias_c: Tensor[Any] | ToBeDetermined = TBD, - bias_o: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + prev_hidden: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + prev_cell: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_i: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_f: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_c: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + w_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_f: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_i: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_c: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1788,7 +1795,7 @@ def __init__( cell_type: Cell, *, name: str | None = None, - # **kwargs: Tensor[Any] | MainValueType, + # **kwargs: Tensor[TypeVarTensorType] | MainValueType, ) -> None: self.cell_type = cell_type super().__init__(name=name) @@ -1806,10 +1813,10 @@ def __init__( cell_type: Cell, max_sequence_length: int, teacher_forcing: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, - **kwargs: Tensor[Any] | MainValueType, + **kwargs: Tensor[TypeVarTensorType] | MainValueType, ) -> None: super().__init__(cell_type=cell_type, name=name) @@ -1891,10 +1898,10 @@ def __init__( self, cell_type: Cell, max_sequence_length: int, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, - **kwargs: Tensor[Any] | ToBeDetermined, + **kwargs: Tensor[TypeVarTensorType] | ToBeDetermined, ) -> None: super().__init__(cell_type=cell_type, name=name) @@ -1948,10 +1955,10 @@ def __init__( self, cell_type: Cell, max_sequence_length: int, - hidden_concat: Tensor[Any] | ToBeDetermined = TBD, + hidden_concat: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, - **kwargs: Tensor[Any] | ToBeDetermined, + **kwargs: Tensor[TypeVarTensorType] | ToBeDetermined, ) -> None: super().__init__(cell_type, name=name) @@ -2031,7 +2038,7 @@ def __init__( max_input_sequence_length: int, max_target_sequence_length: int, teacher_forcing: bool = False, - indices: Tensor[Any] | ToBeDetermined = TBD, + indices: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2134,8 +2141,8 @@ def __init__( self, get_final_distance: bool = True, robust: bool = True, - input1: Tensor[Any] | ToBeDetermined = TBD, - input2: Tensor[Any] | ToBeDetermined = TBD, + input1: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input2: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2194,9 +2201,9 @@ def __init__( self, degree: int, dimension: int | None = None, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, - bias: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2238,9 +2245,9 @@ def __init__( self, exact_distances: bool = True, robust: bool = True, - distances: Tensor[Any] | ToBeDetermined = TBD, - pred_distances: Tensor[Any] | ToBeDetermined = TBD, - norm: Tensor[Any] | ToBeDetermined = TBD, + distances: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + pred_distances: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + norm: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2338,9 +2345,9 @@ def __init__( exact_distances: bool = True, calculate_p_joint: bool = False, perplexity: float = 20.0, - distances: Tensor[Any] | ToBeDetermined = TBD, - pred_distances: Tensor[Any] | ToBeDetermined = TBD, - p_joint: Tensor[Any] | ToBeDetermined = TBD, + distances: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + pred_distances: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + p_joint: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2435,10 +2442,10 @@ def __init__( self, base_model: MDSCore | TSNECore, input_type: str = "distances", - input: Tensor[Any] | ToBeDetermined = TBD, - coords: Tensor[Any] | ToBeDetermined = TBD, - norm: Tensor[Any] | ToBeDetermined = TBD, - predicted_coords: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + norm: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + predicted_coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2544,10 +2551,10 @@ def __init__( self, prediction_dim: int, input_type: str = "distances", - input: Tensor[Any] | ToBeDetermined = TBD, - coords: Tensor[Any] | ToBeDetermined = TBD, - norm: Tensor[Any] | ToBeDetermined = TBD, - predicted_coords: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + norm: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + predicted_coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2602,9 +2609,9 @@ def __init__( input_type: str = "distances", preplexity: float = 20.0, calculate_p_joint: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, - norm: Tensor[Any] | ToBeDetermined = TBD, - predicted_coords: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + norm: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + predicted_coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2658,14 +2665,14 @@ class GaussProcessRegressionCore(Model): def __init__( self, - s: Tensor[Any] | ToBeDetermined = TBD, - k: Tensor[Any] | ToBeDetermined = TBD, - k_star: Tensor[Any] | ToBeDetermined = TBD, - mu: Tensor[Any] | ToBeDetermined = TBD, - label: Tensor[Any] | ToBeDetermined = TBD, - loss: Tensor[Any] | ToBeDetermined = TBD, - prediction: Tensor[Any] | ToBeDetermined = TBD, - confidence: Tensor[Any] | ToBeDetermined = TBD, + s: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + k: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + k_star: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + mu: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + loss: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + prediction: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + confidence: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2785,11 +2792,11 @@ class GPRLoss(Model): def __init__( self, robust: bool = False, - labels: Tensor[Any] | ToBeDetermined = TBD, - mu: Tensor[Any] | ToBeDetermined = TBD, - L: Tensor[Any] | ToBeDetermined = TBD, - K_term: Tensor[Any] | ToBeDetermined = TBD, - alpha: Tensor[Any] | ToBeDetermined = TBD, + labels: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + mu: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + L: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + K_term: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + alpha: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2877,8 +2884,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[Any] | ToBeDetermined = TBD, - label: Tensor[Any] | ToBeDetermined = TBD, + pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2952,8 +2959,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[Any] | ToBeDetermined = TBD, - label: Tensor[Any] | ToBeDetermined = TBD, + pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3014,8 +3021,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[Any] | ToBeDetermined = TBD, - label: Tensor[Any] | ToBeDetermined = TBD, + pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3175,8 +3182,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[Any] | ToBeDetermined = TBD, - label: Tensor[Any] | ToBeDetermined = TBD, + pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3335,8 +3342,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[Any] | ToBeDetermined = TBD, - label: Tensor[Any] | ToBeDetermined = TBD, + pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3491,8 +3498,8 @@ def __init__( self, n_classes: int, is_label_one_hot: bool = True, - pred: Tensor[Any] | ToBeDetermined = TBD, - label: Tensor[Any] | ToBeDetermined = TBD, + pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3548,7 +3555,10 @@ class SiLU(Model): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(name=name) diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index 1aee39b2..26aa9856 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -15,7 +15,6 @@ from __future__ import annotations from types import NoneType -from typing import Any from ..core import Constant, Dtype from ..framework.common import ( @@ -152,8 +151,8 @@ def __init__( formula_key: str, polymorphic_constraint: bool = True, name: str | None = None, - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, **kwargs: BaseKey, ) -> None: default_kwargs: dict[str, BaseKey] = { @@ -193,8 +192,8 @@ def __call__( # type: ignore[override] class SquaredError(SupervisedLoss): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -206,8 +205,8 @@ def __init__( class AbsoluteError(SupervisedLoss): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -219,8 +218,8 @@ def __init__( class HingeLoss(SupervisedLoss): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -237,8 +236,8 @@ def __init__( class QuadHingeLoss(SupervisedLoss): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -265,8 +264,8 @@ class QuantileLoss(PrimitiveModel): def __init__( self, quantile: int | float | ToBeDetermined = TBD, - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -325,8 +324,8 @@ def __init__( self, input_type: str = "logits", weights: list[float] | str = "", - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, robust: bool | ToBeDetermined = TBD, cutoff: ConstantType | ToBeDetermined = TBD, *, @@ -427,8 +426,8 @@ class KLDivergence(PrimitiveModel): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, cutoff: ConstantType | ToBeDetermined = TBD, name: str | None = None, ) -> None: @@ -482,9 +481,9 @@ def __init__( self, input_type: str = "logits", pos_weight: float | str | ToBeDetermined = 1.0, - input: Tensor[Any] | ToBeDetermined = TBD, - target: Tensor[Any] | ToBeDetermined = TBD, - cutoff: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + target: Tensor[int | float | bool] | ToBeDetermined = TBD, + cutoff: Tensor[int | float | bool] | ToBeDetermined = TBD, robust: bool | ToBeDetermined = TBD, *, name: str | None = None, @@ -556,9 +555,9 @@ class Log(PrimitiveModel): def __init__( self, robust: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, - cutoff: Tensor[Any] | ToBeDetermined = TBD, + cutoff: Tensor[int | float | bool] | ToBeDetermined = TBD, name: str | None = None, ) -> None: self.robust = robust @@ -611,7 +610,7 @@ class StableReciprocal(PrimitiveModel): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, cutoff: Tensor[int | float] | ToBeDetermined = TBD, *, name: str | None = None, @@ -642,7 +641,10 @@ def __call__( # type: ignore[override] class Sign(SingleInputOperation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="sign", @@ -655,7 +657,10 @@ def __init__( class Square(SingleInputOperation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(formula_key="square", name=name, input=input) @@ -669,7 +674,7 @@ def __init__( self, formula_key: str, polymorphic_constraint: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, **kwargs: BaseKey, @@ -700,7 +705,10 @@ def __call__( # type: ignore[override] class Relu(Activation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="relu", @@ -715,7 +723,7 @@ class Gelu(Activation): def __init__( self, approximate: bool = False, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -739,7 +747,10 @@ def __call__( # type: ignore[override] class Sigmoid(Activation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(formula_key="sigmoid", name=name, input=input) @@ -747,7 +758,7 @@ def __init__( class Softmax(Activation): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, axis: int | None | ToBeDetermined = TBD, *, name: str | None = None, @@ -766,14 +777,20 @@ def __call__( # type: ignore[override] class Softplus(Activation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(formula_key="softplus", name=name, input=input) class Tanh(Activation): def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__(formula_key="tanh", name=name, input=input) @@ -785,8 +802,8 @@ class LeakyRelu(Activation): def __init__( self, - slope: Tensor[Any] | int | float | ToBeDetermined = TBD, - input: Tensor[Any] | ToBeDetermined = TBD, + slope: Tensor[int | float | bool] | int | float | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -811,7 +828,10 @@ class StopGradient(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="stop_gradient", @@ -833,8 +853,8 @@ class CartesianDifference(PrimitiveModel): def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -869,7 +889,7 @@ def __init__( axis: int | None | ToBeDetermined = 0, *, name: str | None = None, - **kwargs: Tensor[Any] | ToBeDetermined, + **kwargs: Tensor[int | float | bool] | ToBeDetermined, ) -> None: self.factory_args = {"n": n, "axis": axis} @@ -930,8 +950,8 @@ class PermuteTensor(PrimitiveModel): def __init__( self, - indices: Tensor[Any] | ToBeDetermined = TBD, - input: Tensor[Any] | ToBeDetermined = TBD, + indices: Tensor[int | float | bool] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -969,13 +989,13 @@ class PrimitiveConvolution1D(PrimitiveModel): def __init__( self, use_bias: bool = True, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, stride: int | ToBeDetermined = TBD, padding: int | tuple[int, int] | ToBeDetermined = TBD, dilation: int | ToBeDetermined = TBD, *, - bias: Tensor[Any] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, name: str | None = None, ) -> None: self.factory_args = {"use_bias": use_bias} @@ -1055,8 +1075,8 @@ class PrimitiveConvolution2D(PrimitiveModel): def __init__( self, use_bias: bool = True, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, stride: int | tuple[int, int] | ToBeDetermined = TBD, padding: int | tuple[int, int] @@ -1064,7 +1084,7 @@ def __init__( | ToBeDetermined = TBD, dilation: int | tuple[int, int] | ToBeDetermined = TBD, *, - bias: Tensor[Any] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, name: str | None = None, ) -> None: self.factory_args = {"use_bias": use_bias} @@ -1146,7 +1166,7 @@ def __init__( self, start_dim: int | ToBeDetermined = 0, end_dim: int | ToBeDetermined = -1, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1190,7 +1210,7 @@ class PrimitiveMaxPool1D(PrimitiveModel): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, kernel_size: int | ToBeDetermined = TBD, stride: int | ToBeDetermined = TBD, padding: int | tuple[int, int] | ToBeDetermined = TBD, @@ -1398,7 +1418,7 @@ class PrimitiveMaxPool2D(PrimitiveModel): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, kernel_size: int | tuple[int, int] | ToBeDetermined = TBD, stride: int | tuple[int, int] | ToBeDetermined = TBD, padding: int @@ -1477,7 +1497,10 @@ class NormModifier(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="norm_modifier", @@ -1505,8 +1528,8 @@ class DistanceMatrix(PrimitiveModel): # TODO: torch.cdist handles batches of matrices, for now we don't. def __init__( self, - left: Tensor[Any] | ToBeDetermined = TBD, - right: Tensor[Any] | ToBeDetermined = TBD, + left: Tensor[int | float | bool] | ToBeDetermined = TBD, + right: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1542,7 +1565,7 @@ class PolynomialFeatures(PrimitiveModel): def __init__( self, degree: int | ToBeDetermined = TBD, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1580,9 +1603,9 @@ class TsnePJoint(PrimitiveModel): def __init__( self, - squared_distances: Tensor[Any] | ToBeDetermined = TBD, + squared_distances: Tensor[int | float | bool] | ToBeDetermined = TBD, target_perplexity: float | ToBeDetermined = TBD, - threshold: Tensor[Any] | ToBeDetermined = TBD, + threshold: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1687,7 +1710,10 @@ class Cholesky(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="cholesky", @@ -1710,9 +1736,9 @@ class GPRAlpha(PrimitiveModel): def __init__( self, - label_mu_diff: Tensor[Any] | ToBeDetermined = TBD, - L: Tensor[Any] | ToBeDetermined = TBD, - K_term: Tensor[Any] | ToBeDetermined = TBD, + label_mu_diff: Tensor[int | float | bool] | ToBeDetermined = TBD, + L: Tensor[int | float | bool] | ToBeDetermined = TBD, + K_term: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1745,9 +1771,9 @@ class GPRVOuter(PrimitiveModel): def __init__( self, - K: Tensor[Any] | ToBeDetermined = TBD, - K_term: Tensor[Any] | ToBeDetermined = TBD, - L: Tensor[Any] | ToBeDetermined = TBD, + K: Tensor[int | float | bool] | ToBeDetermined = TBD, + K_term: Tensor[int | float | bool] | ToBeDetermined = TBD, + L: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1775,7 +1801,10 @@ class TransposedDiagonal(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="transposed_diag", @@ -1905,7 +1934,7 @@ class BroadcastTo(PrimitiveModel): def __init__( self, shape: tuple[int, ...] | ToBeDetermined = TBD, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1941,8 +1970,8 @@ class Eigvalsh(PrimitiveModel): def __init__( self, - K_term: Tensor[Any] | ToBeDetermined = TBD, - L: Tensor[Any] | ToBeDetermined = TBD, + K_term: Tensor[int | float | bool] | ToBeDetermined = TBD, + L: Tensor[int | float | bool] | ToBeDetermined = TBD, threshold: ConstantType | ToBeDetermined = TBD, *, name: str | None = None, @@ -1971,7 +2000,10 @@ class Squeeze(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="squeeze", @@ -1998,8 +2030,8 @@ class AUCCore(PrimitiveModel): def __init__( self, - input: Tensor[Any] | ToBeDetermined = TBD, - label: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + label: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2029,8 +2061,8 @@ def __init__( self, num_embeddings: int | None = None, dim: int | None = None, - input: Tensor[Any] | ToBeDetermined = TBD, - weight: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2074,10 +2106,10 @@ def __init__( scale: None | int | float | ToBeDetermined = None, dropout_p: float | ToBeDetermined = 0.0, use_attn_mask: bool = False, - query: Tensor[Any] | ToBeDetermined = TBD, - key: Tensor[Any] | ToBeDetermined = TBD, - value: Tensor[Any] | ToBeDetermined = TBD, - attn_mask: Tensor[Any] | ToBeDetermined = TBD, + query: Tensor[int | float | bool] | ToBeDetermined = TBD, + key: Tensor[int | float | bool] | ToBeDetermined = TBD, + value: Tensor[int | float | bool] | ToBeDetermined = TBD, + attn_mask: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2153,7 +2185,7 @@ def __init__( self, hidden_dim: int | ToBeDetermined, max_len: int | ToBeDetermined = 5000, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2190,7 +2222,7 @@ def __init__( self, axis1: int | ToBeDetermined, axis2: int | ToBeDetermined, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2230,9 +2262,9 @@ class Where(PrimitiveModel): def __init__( self, - cond: Tensor[Any] | ToBeDetermined = TBD, - input1: Tensor[Any] | ToBeDetermined = TBD, - input2: Tensor[Any] | ToBeDetermined = TBD, + cond: Tensor[int | float | bool] | ToBeDetermined = TBD, + input1: Tensor[int | float | bool] | ToBeDetermined = TBD, + input2: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2268,7 +2300,10 @@ class IsNan(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="isnan", @@ -2288,7 +2323,10 @@ class Unique(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="unique", @@ -2310,8 +2348,8 @@ class Trapezoid(PrimitiveModel): def __init__( self, - x: Tensor[Any] | ToBeDetermined = TBD, - y: Tensor[Any] | ToBeDetermined = TBD, + x: Tensor[int | float | bool] | ToBeDetermined = TBD, + y: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2344,7 +2382,7 @@ def __init__( nan: float | ToBeDetermined = 0.0, posinf: float | None | ToBeDetermined = None, neginf: float | None | ToBeDetermined = None, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2379,7 +2417,7 @@ class Pad(PrimitiveModel): def __init__( self, pad_width: tuple[tuple[int, int], ...] | ToBeDetermined = TBD, - input: Tensor[Any] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2409,7 +2447,10 @@ class ZerosLike(PrimitiveModel): output: Connection def __init__( - self, input: Tensor[Any] | ToBeDetermined = TBD, *, name: str | None = None + self, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + *, + name: str | None = None, ) -> None: super().__init__( formula_key="zeros_like", diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index 184bc9de..b7d7903e 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -110,7 +110,11 @@ def __init__(self, model: BaseModel) -> None: self.regularization_keys: list[str] = [] self.metric_keys: list[str] = [] self.loss_combiner: BaseModel = Sum() - self.reg_coef_map: dict[float | Tensor[Any], set[Connection]] = {} + # TODO: Update type of reg_coef_map. float | Tensor[int | float | bool] + # is not correct. + self.reg_coef_map: dict[ + float | Tensor[int | float | bool], set[Connection] + ] = {} self.geomean_map: dict[str, list[tuple[Connection, float]]] = {} self.reduce_inputs: dict[str, list[tuple[Connection, Connection]]] = {} diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index e7b711c1..09f540de 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -18,7 +18,7 @@ from collections.abc import Callable, Sequence from copy import deepcopy from types import EllipsisType, UnionType -from typing import Any, TypedDict +from typing import Any, TypedDict, get_origin from ..framework.common import ( TBD, @@ -135,7 +135,7 @@ def create_iokey_kwargs( kwargs["value"] = Tensor(val["tensor"]) if isinstance(val, dict) else val if (typ := info_cpy.get("type")) is not None: # Convert type strings to type objects. - kwargs["type"] = Tensor if typ == "tensor" else eval(typ) + kwargs["type"] = Tensor[int | float | bool] if typ == "tensor" else eval(typ) if (conns := info_cpy.get("connect")) is not None: kwargs["connections"] = { getattr(submodels_dict[value[0]], value[1]) @@ -210,7 +210,7 @@ def dict_to_model( set_types = {} for key, typ in types.items(): if typ == "tensor": - set_types[key] = Tensor + set_types[key] = Tensor[int | float | bool] else: # TODO: Get rid of using eval method. Find more secure # way to convert strings into types and generic types. @@ -290,7 +290,7 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: for key, con in model.conns.all.items(): edge = con.metadata - if edge.edge_type is Tensor and not con.is_key_autogenerated: + if get_origin(edge.edge_type) is Tensor and not con.is_key_autogenerated: differentiablility_info[key] = edge.differentiable for shape in model.assigned_shapes: @@ -300,7 +300,7 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: assigned_constraints.append(constrain) for key, typ in model.assigned_types.items(): - if typ is Tensor: + if get_origin(typ) is Tensor: types[key] = "tensor" else: types[key] = str(typ) @@ -387,7 +387,7 @@ def connection_to_dict( elif is_valued and connection in model.conns.input_connections: val = connection.metadata.value assert not isinstance(val, ToBeDetermined) - if connection.metadata.edge_type is Tensor: + if get_origin(connection.metadata.edge_type) is Tensor: val = {"tensor": val} if connection.key.startswith("$"): key_value = val diff --git a/mithril/utils/func_utils.py b/mithril/utils/func_utils.py index 68aa04ef..31ba3d59 100644 --- a/mithril/utils/func_utils.py +++ b/mithril/utils/func_utils.py @@ -14,7 +14,7 @@ from collections.abc import Callable from copy import deepcopy -from typing import Any +from typing import Any, get_origin from ..core import DataType from ..framework.common import ( # , Scalar, Tensor @@ -162,7 +162,7 @@ def reorganize_args( def is_make_array_required(data: IOHyperEdge) -> bool: - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: assert data.shape is not None _temp_shape = next(iter(data.shape.reprs)) # It is needed to guarantee that Tensor is at least one dimensional. diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 43b749f5..8c5d2f71 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import get_origin from mithril import Backend, Constant, compile, epsilon_table from mithril.framework.common import IOHyperEdge, Tensor @@ -56,7 +57,7 @@ def evaluate_case( model = finalize_model(current_case) # Convert static keys to array if they are not scalar. for key, value in static_keys.items(): - if model.conns.get_metadata(key).edge_type is not Tensor: + if get_origin(model.conns.get_metadata(key).edge_type) is not Tensor: static_keys[key] = value else: static_keys[key] = convert_to_array(backend, value) @@ -91,7 +92,7 @@ def evaluate_case( data_value = epsilon_table[backend.precision][data_value] assert data_value == copied_data.value - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: assert id(data.value) == id(copied_data.value) # Evaluate model. diff --git a/tests/scripts/test_constant_inputs.py b/tests/scripts/test_constant_inputs.py index 48ee1f85..9e539312 100644 --- a/tests/scripts/test_constant_inputs.py +++ b/tests/scripts/test_constant_inputs.py @@ -936,7 +936,7 @@ def test_str_axis_set_shapes(): assert str(err_info.value) == ( "Acceptable types are None | int | list[int] | tuple[int, ...], " - "but type value is provided!" + "but type is provided!" ) @@ -947,7 +947,7 @@ def test_float_axis_2(): model1 += mean1(axis=3.0) assert str(err_info.value) == ( "Acceptable types are None | int | list[int] | tuple[int, ...], but " - " type value is provided!" + " type is provided!" ) @@ -957,7 +957,7 @@ def test_float_axis_2_set_values(): mean1.set_values({"axis": 3.0}) assert str(err_info.value) == ( "Acceptable types are None | int | list[int] | tuple[int, ...], but " - " type value is provided!" + " type is provided!" ) @@ -976,8 +976,7 @@ def test_static_type(): model1 += poly_feat_1(input="", degree=conv2d.stride) assert str(err.value) == ( - "Acceptable types are tuple[int, int], but " - "type value is provided!" + "Acceptable types are tuple[int, int], but " "type is provided!" ) @@ -991,8 +990,7 @@ def test_static_type_set_value(): model1 += poly_feat_1(input="", degree=conv2d.stride) assert str(err.value) == ( - "Acceptable types are tuple[int, int], but " - "type value is provided!" + "Acceptable types are tuple[int, int], but " "type is provided!" ) @@ -1584,9 +1582,9 @@ def test_composite_4_set_values(): def test_composite_5(): - list1 = Tensor(np.random.randn(2, 3, 4).tolist()) - list2 = Tensor(np.random.randn(1, 3, 4).tolist()) - list3 = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) + list1: Tensor[float] = Tensor(np.random.randn(2, 3, 4).tolist()) + list2: Tensor[float] = Tensor(np.random.randn(1, 3, 4).tolist()) + list3: Tensor[float] = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) model = Model() add_model_1 = Add() add_model_2 = Add() @@ -1599,9 +1597,9 @@ def test_composite_5(): def test_composite_5_set_values(): - list1 = Tensor(np.random.randn(2, 3, 4).tolist()) - list2 = Tensor(np.random.randn(1, 3, 4).tolist()) - list3 = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) + list1: Tensor[float] = Tensor(np.random.randn(2, 3, 4).tolist()) + list2: Tensor[float] = Tensor(np.random.randn(1, 3, 4).tolist()) + list3: Tensor[float] = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) model = Model() add_model_1 = Add() add_model_2 = Add() @@ -1617,9 +1615,9 @@ def test_composite_5_set_values(): def test_composite_6(): - list1 = Tensor(np.random.randn(2, 3, 4).tolist()) - list2 = Tensor(np.random.randn(1, 3, 4).tolist()) - list3 = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) + list1: Tensor[float] = Tensor(np.random.randn(2, 3, 4).tolist()) + list2: Tensor[float] = Tensor(np.random.randn(1, 3, 4).tolist()) + list3: Tensor[float] = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) model = Model() add_model_1 = Add() add_model_2 = Add() @@ -1631,9 +1629,9 @@ def test_composite_6(): def test_composite_6_set_values(): - list1 = Tensor(np.random.randn(2, 3, 4).tolist()) - list2 = Tensor(np.random.randn(1, 3, 4).tolist()) - list3 = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) + list1: Tensor[float] = Tensor(np.random.randn(2, 3, 4).tolist()) + list2: Tensor[float] = Tensor(np.random.randn(1, 3, 4).tolist()) + list3: Tensor[float] = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) model = Model() add_model_1 = Add() add_model_2 = Add() @@ -1649,9 +1647,9 @@ def test_composite_6_set_values(): def test_composite_7(): - list1 = Tensor(np.random.randn(2, 3, 4).tolist()) - list2 = Tensor(np.random.randn(1, 3, 4).tolist()) - list3 = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) + list1: Tensor[float] = Tensor(np.random.randn(2, 3, 4).tolist()) + list2: Tensor[float] = Tensor(np.random.randn(1, 3, 4).tolist()) + list3: Tensor[float] = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) model = Model() add_model_1 = Add() add_model_2 = Add() @@ -1664,9 +1662,9 @@ def test_composite_7(): def test_composite_7_set_values(): - list1 = Tensor(np.random.randn(2, 3, 4).tolist()) - list2 = Tensor(np.random.randn(1, 3, 4).tolist()) - list3 = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) + list1: Tensor[float] = Tensor(np.random.randn(2, 3, 4).tolist()) + list2: Tensor[float] = Tensor(np.random.randn(1, 3, 4).tolist()) + list3: Tensor[float] = Tensor(np.random.randn(2, 2, 1, 1, 1).tolist()) model = Model() add_model_1 = Add() add_model_2 = Add() @@ -1682,7 +1680,7 @@ def test_composite_7_set_values(): def test_composite_conv_mean(): - list1 = Tensor(np.random.randn(1, 1, 8, 8).tolist()) + list1: Tensor[float] = Tensor(np.random.randn(1, 1, 8, 8).tolist()) model = Model() conv_model = Convolution2D(kernel_size=2, out_channels=1, stride=(2, 3)) reduce_model = Mean(axis=TBD) @@ -1693,7 +1691,7 @@ def test_composite_conv_mean(): def test_composite_conv_mean_set_values(): - list1 = Tensor(np.random.randn(1, 1, 8, 8).tolist()) + list1: Tensor[float] = Tensor(np.random.randn(1, 1, 8, 8).tolist()) model = Model() conv_model = Convolution2D(kernel_size=2, out_channels=1, stride=(2, 3)) reduce_model = Mean(axis=TBD) @@ -1705,7 +1703,7 @@ def test_composite_conv_mean_set_values(): def test_composite_conv_mean_2(): - list1 = Tensor(np.ones((1, 1, 8, 8)).tolist()) + list1: Tensor[float] = Tensor(np.ones((1, 1, 8, 8)).tolist()) model = Model() conv_model = Convolution2D(kernel_size=2, out_channels=1, stride=TBD) reduce_model = Sum(axis=TBD) @@ -1721,7 +1719,7 @@ def test_composite_conv_mean_2(): def test_composite_conv_mean_2_set_values(): - list1 = Tensor(np.ones((1, 1, 8, 8)).tolist()) + list1: Tensor[float] = Tensor(np.ones((1, 1, 8, 8)).tolist()) model = Model() conv_model = Convolution2D(kernel_size=2, out_channels=1, stride=TBD) reduce_model = Sum(axis=TBD) diff --git a/tests/scripts/test_constr_counter.py b/tests/scripts/test_constr_counter.py index 047cb768..eabc17f8 100644 --- a/tests/scripts/test_constr_counter.py +++ b/tests/scripts/test_constr_counter.py @@ -13,6 +13,8 @@ # limitations under the License. +from typing import get_origin + import pytest from mithril.framework.common import ( @@ -48,8 +50,12 @@ def dummy_constraint(output: IOHyperEdge, input: IOHyperEdge): # updated_symbols = set() updates = Updates() status = False - output_repr = output._temp_shape if output.edge_type is Tensor else output.value - input_repr = input._temp_shape if input.edge_type is Tensor else input.value + output_repr = ( + output._temp_shape if get_origin(output.edge_type) is Tensor else output.value + ) + input_repr = ( + input._temp_shape if get_origin(input.edge_type) is Tensor else input.value + ) assert isinstance(output_repr, ShapeRepr) assert isinstance(input_repr, ShapeRepr) if bool(input_repr.root) ^ bool(output_repr.root): @@ -101,7 +107,10 @@ def __init__(self) -> None: super().__init__( formula_key="buffer", input=BaseKey(shape=[("Var1", ...)], type=Tensor), - output=BaseKey(shape=[("Var2", ...)], type=Tensor), + output=BaseKey( + shape=[("Var2", ...)], + type=Tensor, + ), ) self._set_constraint(fn=dummy_constraint, keys=["output", "input"]) diff --git a/tests/scripts/test_constraints.py b/tests/scripts/test_constraints.py index 19f6cfac..b66518e2 100644 --- a/tests/scripts/test_constraints.py +++ b/tests/scripts/test_constraints.py @@ -15,7 +15,7 @@ from collections.abc import Callable, Mapping from copy import deepcopy from types import EllipsisType, NoneType, UnionType -from typing import Any, TypeGuard +from typing import Any, TypeGuard, get_origin import numpy as np import pytest @@ -118,7 +118,9 @@ def shape_map_to_tensor( # Simply converts ShapeRepr objects to Tensor types. tensor_dict = {} for key, value in shape_map.items(): - tensor = Tensor(type=float | int | bool, shape=value.node) + tensor: Tensor[int | float | bool] = Tensor( + type=float | int | bool, shape=value.node + ) edge = IOHyperEdge(value=tensor, key_origin=key) # set temp_shape. Since temp_shape of a Tensor initialized as None in its # constructor. @@ -216,7 +218,7 @@ def assert_shape_results( shapes = {} assignments: AssignmentType = {} for key, value in data.items(): - if value.edge_type is Tensor: + if get_origin(value.edge_type) is Tensor: assert value.shape is not None shapes[key] = value.shape.get_shapes(uni_cache, var_cache, verbose=True) shape_repr = value._temp_shape @@ -262,7 +264,7 @@ def assert_value_results( assert data[key].value == value else: # If value is a tensor of any supported backend. - assert data[key].edge_type is Tensor + assert get_origin(data[key].edge_type) is Tensor d_val = data[key].value assert GenericDataType.is_tensor_type(d_val) assert (d_val == value).all() diff --git a/tests/scripts/test_hyperedge_my_tensor.py b/tests/scripts/test_hyperedge.py similarity index 70% rename from tests/scripts/test_hyperedge_my_tensor.py rename to tests/scripts/test_hyperedge.py index 418b025a..b9e60f56 100644 --- a/tests/scripts/test_hyperedge_my_tensor.py +++ b/tests/scripts/test_hyperedge.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import get_origin + import pytest from mithril.framework.common import ( @@ -31,9 +33,9 @@ def test_init_with_tensor_default_type(): - edge = IOHyperEdge(Tensor) + edge = IOHyperEdge(Tensor[int | float | bool]) assert ( - edge.edge_type is Tensor + get_origin(edge.edge_type) is Tensor and isinstance(edge._value, Tensor) and edge.value_type == int | float | bool and edge.value is TBD @@ -45,7 +47,7 @@ def test_init_with_tensor_default_type(): def test_init_with_tensor_int_or_float_type(): edge = IOHyperEdge(Tensor[int | float]) assert ( - edge.edge_type is Tensor + get_origin(edge.edge_type) is Tensor and isinstance(edge._value, Tensor) and edge.value_type == int | float and edge.value is TBD @@ -57,7 +59,7 @@ def test_init_with_tensor_int_or_float_type(): def test_init_with_tensor_type_tensor_value(): edge = IOHyperEdge(Tensor[int | float], value=Tensor([[2.0]])) assert ( - edge.edge_type is Tensor + get_origin(edge.edge_type) is Tensor and isinstance(edge._value, Tensor) and edge.value_type is float and edge.value == [[2.0]] @@ -83,7 +85,7 @@ def test_init_with_wrong_tensor_type_tensor_value(): assert ( str(err_info.value) == "Acceptable types are bool | int, but type " - "value is provided!" + "is provided!" ) @@ -93,7 +95,7 @@ def test_init_with_wrong_scalar_type_scalar_value(): assert ( str(err_info.value) == "Acceptable types are bool | int, but type " - "value is provided!" + "is provided!" ) @@ -113,9 +115,9 @@ def test_set_tensor_type(): edge = IOHyperEdge() assert edge.edge_type is ToBeDetermined and edge._value is TBD and edge.value is TBD assert edge.shape is None - edge.set_type(Tensor) + edge.set_type(Tensor[int | float | bool]) assert ( - edge.edge_type is Tensor + get_origin(edge.edge_type) is Tensor and isinstance(edge._value, Tensor) and edge.value is TBD ) @@ -130,7 +132,7 @@ def test_set_generic_tensor_type(): assert edge.shape is None edge.set_type(Tensor[int | float]) assert ( - edge.edge_type is Tensor + get_origin(edge.edge_type) is Tensor and isinstance(edge._value, Tensor) and edge.value is TBD ) @@ -151,15 +153,22 @@ def test_set_scalar_type(): def test_set_scalar_edge_type_to_tensor_type(): edge = IOHyperEdge(type=int | float) with pytest.raises(TypeError) as err_info: - edge.set_type(Tensor) - assert str(err_info.value) == "Can not set Tensor type to a Scalar edge." + edge.set_type(Tensor[int | float | bool]) + assert ( + str(err_info.value) == "Acceptable types are float | int, but " + "mithril.framework.common.Tensor[int | float | bool] type is provided!" + ) def test_set_tensor_edge_type_to_scalar_type(): - edge = IOHyperEdge(type=Tensor) + edge = IOHyperEdge(type=Tensor[int | float | bool]) with pytest.raises(TypeError) as err_info: edge.set_type(int | float) - assert str(err_info.value) == "Can not set Scalar type to a Tensor edge." + assert ( + str(err_info.value) + == "Acceptable types are mithril.framework.common.Tensor[int | float | bool], " + "but float | int type is provided!" + ) ############## Value Setting Tests ############## @@ -167,10 +176,10 @@ def test_set_tensor_edge_type_to_scalar_type(): def test_init_with_tensor_value(): shape_node = ShapeRepr(root=Variadic()).node - tensor = Tensor([[2.0]], shape=shape_node) + tensor: Tensor[float] = Tensor([[2.0]], shape=shape_node) edge = IOHyperEdge(value=tensor) assert ( - edge.edge_type is Tensor + get_origin(edge.edge_type) is Tensor and isinstance(edge._value, Tensor) and edge.value_type is float and edge.value == [[2.0]] @@ -185,11 +194,11 @@ def test_init_with_tensor_value(): def test_set_non_typed_edge_with_tensor_value(): shape_node = ShapeRepr(root=Variadic()).node - tensor = Tensor([[2.0]], shape=shape_node) + tensor: Tensor[float] = Tensor([[2.0]], shape=shape_node) edge = IOHyperEdge() edge.set_value(tensor) assert ( - edge.edge_type is Tensor + get_origin(edge.edge_type) is Tensor and isinstance(edge._value, Tensor) and edge.value_type is float and edge.value == [[2.0]] @@ -206,13 +215,13 @@ def test_set_non_typed_edge_with_tensor_value(): def test_set_tensor_edge_with_tensor_value(): shape_node = ShapeRepr(root=Variadic()).node - tensor = Tensor([[2.0]], shape=shape_node) - edge = IOHyperEdge(type=Tensor) + tensor: Tensor[float] = Tensor([[2.0]], shape=shape_node) + edge = IOHyperEdge(type=Tensor[int | float | bool]) assert isinstance(edge._value, Tensor) edge_tensor = edge._value edge.set_value(tensor) assert ( - edge.edge_type is Tensor + get_origin(edge.edge_type) is Tensor and isinstance(edge._value, Tensor) and edge.value_type is float and edge.value == [[2.0]] @@ -238,7 +247,7 @@ def test_set_scalar_edge_with_scalar_value(): def test_set_scalar_edge_with_tensor_value(): shape_node = ShapeRepr(root=Variadic()).node - tensor = Tensor([[2.0]], shape=shape_node) + tensor: Tensor[float] = Tensor([[2.0]], shape=shape_node) edge = IOHyperEdge(type=int | float) with pytest.raises(ValueError) as err_info: edge.set_value(tensor) @@ -246,7 +255,7 @@ def test_set_scalar_edge_with_tensor_value(): def test_set_tensor_edge_with_scalar_value(): - edge = IOHyperEdge(type=Tensor) + edge = IOHyperEdge(type=Tensor[int | float | bool]) with pytest.raises(ValueError) as err_info: edge.set_value(3) assert str(err_info.value) == "Can not set Scalar value to a Tensor edge." @@ -254,7 +263,7 @@ def test_set_tensor_edge_with_scalar_value(): def test_set_tensor_edge_with_different_tensor_value(): shape_node = ShapeRepr(root=Variadic()).node - tensor = Tensor([[2.0]], shape=shape_node) + tensor: Tensor[float] = Tensor([[2.0]], shape=shape_node) edge = IOHyperEdge(value=tensor) with pytest.raises(ValueError) as err_info: edge.set_value(Tensor([[3.0]], shape=ShapeRepr(root=Variadic()).node)) @@ -268,12 +277,12 @@ def test_set_tensor_edge_with_different_type_tensor_value(): edge = IOHyperEdge(type=Tensor[int | bool]) with pytest.raises(TypeError) as err_info: shape_node = ShapeRepr(root=Variadic()).node - tensor = Tensor([[2.0]], shape=shape_node) + tensor: Tensor[float] = Tensor([[2.0]], shape=shape_node) edge.set_value(tensor) assert ( str(err_info.value) == "Acceptable types are bool | int, but type " - "value is provided!" + "is provided!" ) @@ -283,7 +292,7 @@ def test_set_scalar_edge_with_different_type_scalar_value(): edge.set_value([1, 2]) assert ( str(err_info.value) == "Acceptable types are bool | int, but list[int] type " - "value is provided!" + "is provided!" ) @@ -328,8 +337,8 @@ def test_match_tensor_edge_with_tensor_edge_with_no_common_types(): edge1.match(edge2) assert ( str(err_info.value) - == "Acceptable types are float | int, but type " - "value is provided!" + == "Acceptable types are mithril.framework.common.Tensor[int | float], " + "but mithril.framework.common.Tensor[bool] type is provided!" ) @@ -339,7 +348,11 @@ def test_match_tensor_edge_with_scalar_edge(): with pytest.raises(TypeError) as err_info: edge1.match(edge2) - assert str(err_info.value) == "Can not set Scalar type to a Tensor edge." + assert ( + str(err_info.value) + == "Acceptable types are mithril.framework.common.Tensor[int | float], " + "but bool | float type is provided!" + ) def test_match_scalar_edge_with_tensor_edge(): @@ -348,7 +361,10 @@ def test_match_scalar_edge_with_tensor_edge(): with pytest.raises(TypeError) as err_info: edge2.match(edge1) - assert str(err_info.value) == "Can not set Tensor type to a Scalar edge." + assert ( + str(err_info.value) == "Acceptable types are bool | float, but " + "mithril.framework.common.Tensor[int | float] type is provided!" + ) def test_match_untyped_edge_with_tensor_edge(): @@ -405,3 +421,79 @@ def test_match_scalar_edge_with_untyped_edge(): assert updates.constraints == set() assert updates.value_updates == set() assert updates.shape_updates == set() + + +def test_match_mixed_type_edge_with_tensor_edge(): + edge1 = IOHyperEdge(type=Tensor[int | float] | int | float) # Mixed type edge. + + constr = Constraint(fn=reduce_type_constraint, type=UpdateType.TYPE) + edge2 = IOHyperEdge(type=Tensor[float | bool]) + edge2.add_constraint(constr) + node2 = edge2.shape + + updates = edge1.match(edge2) + assert edge1.shape is not None + assert edge2.shape is not None + assert ( + isinstance(edge1._value, Tensor) + and edge1._value.referees == {edge1} + and edge1.value_type is float + ) + assert edge1.edge_type == Tensor[float] + assert not edge1.is_polymorphic + assert isinstance(edge2._value, Tensor) and edge2._value.referees == {edge1} + assert edge1.shape is edge2.shape + assert edge1.shape.referees == {edge1} + assert edge1.all_constraints == {constr} and edge2.all_constraints == set() + assert updates.constraints == {constr} + assert updates.value_updates == set() + assert updates.shape_updates == set() + assert updates.node_updates == {node2} # NOTE: The shape node is useless actually. + + +def test_match_mixed_type_edge_with_scalar_edge(): + edge1 = IOHyperEdge(type=Tensor[int | float] | int | float) # Mixed type edge. + + constr = Constraint(fn=reduce_type_constraint, type=UpdateType.TYPE) + edge2 = IOHyperEdge(type=float | bool) + edge2.add_constraint(constr) + + updates = edge1.match(edge2) + assert edge1.edge_type is float + assert not edge1.is_polymorphic + assert edge1.all_constraints == {constr} and edge2.all_constraints == set() + assert updates.constraints == {constr} + assert updates.value_updates == set() + assert updates.shape_updates == set() + + +def test_match_mixed_type_edge_with_mixed_type_edge_1(): + edge1 = IOHyperEdge(type=Tensor[int | float] | int | float) # Mixed type edge. + + constr = Constraint(fn=reduce_type_constraint, type=UpdateType.TYPE) + edge2 = IOHyperEdge(type=float | bool | Tensor[float | bool]) + edge2.add_constraint(constr) + + updates = edge1.match(edge2) + assert edge1.edge_type == float | Tensor[float] + assert edge1.is_polymorphic + assert edge1.all_constraints == {constr} and edge2.all_constraints == set() + assert updates.constraints == {constr} + assert updates.value_updates == set() + assert updates.shape_updates == set() + + +def test_match_mixed_type_edge_with_mixed_type_edge_2(): + edge1 = IOHyperEdge(type=Tensor[int | float] | int | float) # Mixed type edge. + + constr = Constraint(fn=reduce_type_constraint, type=UpdateType.TYPE) + edge2 = IOHyperEdge(type=float | bool | Tensor[bool]) + edge2.add_constraint(constr) + + updates = edge1.match(edge2) + assert edge1.edge_type is float + assert not edge1.is_polymorphic + assert edge1.all_constraints == {constr} and edge2.all_constraints == set() + assert updates.constraints == {constr} + assert updates.value_updates == set() + assert updates.shape_updates == set() diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index c323ffbe..e0e8acef 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -13,6 +13,7 @@ # limitations under the License. from itertools import product +from typing import get_origin import numpy as np import pytest @@ -662,7 +663,7 @@ def test_iokey_values_12(): model += sig_model_2( input=IOKey(shape=[1, 2, 3, 4], name="input"), output=IOKey(name="output2") ) - assert sig_model_1.input.data.metadata.edge_type is Tensor + assert get_origin(sig_model_1.input.data.metadata.edge_type) is Tensor assert sig_model_1.input.data.metadata.shape is not None assert sig_model_1.input.data.metadata.shape.get_shapes() == [1, 2, 3, 4] @@ -819,8 +820,6 @@ def test_iokey_scalar_output_all_args(): continue try: - if name is None and expose: - ... # try to extend the model model += sub_model(input="input", output=output) except Exception as e: @@ -841,7 +840,11 @@ def test_iokey_scalar_output_all_args(): # Since providing shape within IOKey means it is Tensor type, # it is an expected type error. assert isinstance(e, TypeError) - assert e.args[0] == "Can not set Tensor type to a Scalar edge." + assert ( + e.args[0] == "Acceptable types are tuple[int, ...], but " + "mithril.framework.common.Tensor[int | float | bool] type " + "is provided!" + ) else: # it is an unexpected error. Raise given exception in that case @@ -934,7 +937,12 @@ def test_iokey_scalar_input_all_args(): # Since providing shape within IOKey means it is Tensor type, # it is an expected type error. assert isinstance(e, TypeError) - assert e.args[0] == "Can not set Tensor type to a Scalar edge." + assert ( + e.args[0] + == "Acceptable types are None | int | list[int] | tuple[int, ...], " + "but mithril.framework.common.Tensor[int | float | bool] type " + "is provided!" + ) else: # it is an unexpected error. Raise given exception in that case @@ -1181,7 +1189,12 @@ def test_iokey_shape_error_1(): with pytest.raises(TypeError) as err_info: model += mean_model(axis=IOKey(name="axis", shape=[2, 3])) - assert str(err_info.value) == "Can not set Tensor type to a Scalar edge." + assert ( + str(err_info.value) + == "Acceptable types are None | int | list[int] | tuple[int, ...], but " + "mithril.framework.common.Tensor[int | float | bool] type " + "is provided!" + ) def test_error_1(): diff --git a/tests/scripts/test_process_sequence.py b/tests/scripts/test_process_sequence.py index 5c1f3131..64c81c02 100644 --- a/tests/scripts/test_process_sequence.py +++ b/tests/scripts/test_process_sequence.py @@ -95,7 +95,7 @@ def test_process_value_inconsistent_shape(): def test_tensor_initialization_1(): sequence: list[list[int] | Sequence[int]] = [[0, 1, 2], range(3, 6)] - tensor = Tensor(sequence) + tensor: Tensor[float] = Tensor(sequence) assert tensor.value == [[0, 1, 2], [3, 4, 5]] assert tensor.shape.get_shapes() == [2, 3] assert tensor.type is int @@ -105,7 +105,7 @@ def test_tensor_initialization_2(): sequence: list[list[list[int | float] | Sequence[int]]] = [ [[0, 1.0, 2], range(3, 6)] ] - tensor = Tensor(sequence) + tensor: Tensor[float] = Tensor(sequence) assert tensor.value == [[[0, 1.0, 2], [3, 4, 5]]] assert tensor.shape.get_shapes() == [1, 2, 3] assert tensor.type is float @@ -115,7 +115,7 @@ def test_tensor_initialization_3(): sequence: list[list[list[int | float] | Sequence[int]]] = [ [[0, 1.0, 2], [True, False, True]] ] - tensor = Tensor(sequence) + tensor: Tensor[float] = Tensor(sequence) assert tensor.value == [[[0, 1.0, 2], [True, False, True]]] assert tensor.shape.get_shapes() == [1, 2, 3] assert tensor.type is float diff --git a/tests/scripts/test_randomized_models_all_backends.py b/tests/scripts/test_randomized_models_all_backends.py index 558bb173..93c46666 100644 --- a/tests/scripts/test_randomized_models_all_backends.py +++ b/tests/scripts/test_randomized_models_all_backends.py @@ -15,6 +15,7 @@ import inspect import json from copy import deepcopy +from typing import get_origin import numpy as np import pytest @@ -179,7 +180,7 @@ def test_randomized(case: str) -> None: ) static_inputs[init_key] = { key: init_backend.array(value) - if model.conns.get_metadata(key).edge_type is Tensor + if get_origin(model.conns.get_metadata(key).edge_type) is Tensor else value for key, value in static_inputs[init_key].items() } @@ -229,7 +230,7 @@ def test_randomized(case: str) -> None: } static_inputs[backend.backend_type] = { key: backend.array(value) - if model.conns.get_metadata(key).edge_type is Tensor + if get_origin(model.conns.get_metadata(key).edge_type) is Tensor else value for key, value in static_inputs[init_key].items() } diff --git a/tests/scripts/test_ref_counts.py b/tests/scripts/test_ref_counts.py index daf15cfc..6b788370 100644 --- a/tests/scripts/test_ref_counts.py +++ b/tests/scripts/test_ref_counts.py @@ -15,6 +15,7 @@ import sys from copy import deepcopy +from typing import get_origin from mithril.framework.common import ( NOT_GIVEN, @@ -89,8 +90,8 @@ def __init__(self) -> None: submodel1 = TestModel() submodel2 = TestModel() - assert submodel1.output.metadata.edge_type is Tensor - assert submodel2.output.metadata.edge_type is Tensor + assert get_origin(submodel1.output.metadata.edge_type) is Tensor + assert get_origin(submodel2.output.metadata.edge_type) is Tensor assert submodel1.output.metadata.shape is not None assert submodel2.input.metadata.shape is not None ref_var1 = next(iter(submodel1.output.metadata.shape.reprs)).root @@ -121,8 +122,8 @@ def __init__(self) -> None: buff_model1 = MyModel() buff_model2 = MyModel() - assert buff_model1.input.metadata.edge_type is Tensor - assert buff_model2.input.metadata.edge_type is Tensor + assert get_origin(buff_model1.input.metadata.edge_type) is Tensor + assert get_origin(buff_model2.input.metadata.edge_type) is Tensor assert buff_model1.input.metadata.shape is not None assert buff_model2.output.metadata.shape is not None ref_var1 = next(iter(buff_model1.input.metadata.shape.reprs)).root @@ -1378,8 +1379,8 @@ def __init__(self) -> None: buff_model1 = MyModel() buff_model2 = MyModel() - assert buff_model1.output.metadata.edge_type is Tensor - assert buff_model2.input.metadata.edge_type is Tensor + assert get_origin(buff_model1.output.metadata.edge_type) is Tensor + assert get_origin(buff_model2.input.metadata.edge_type) is Tensor assert buff_model1.output.metadata.shape is not None assert buff_model2.input.metadata.shape is not None ref_var1 = next(iter(buff_model1.output.metadata.shape.reprs))[0] @@ -1391,7 +1392,7 @@ def __init__(self) -> None: diff_roots = set() for tensor in get_all_data(model): - assert tensor.edge_type is Tensor + assert get_origin(tensor.edge_type) is Tensor node = tensor.shape assert node is not None for repr in node.reprs: @@ -1419,8 +1420,8 @@ def __init__(self) -> None: submodel1 = MyModel() submodel2 = MyModel() - assert submodel1.output.metadata.edge_type is Tensor - assert submodel2.input.metadata.edge_type is Tensor + assert get_origin(submodel1.output.metadata.edge_type) is Tensor + assert get_origin(submodel2.input.metadata.edge_type) is Tensor assert submodel1.output.metadata.shape is not None assert submodel2.input.metadata.shape is not None ref_var1 = next(iter(submodel1.output.metadata.shape.reprs))[0] @@ -1451,8 +1452,8 @@ def __init__(self) -> None: submodel1 = MyModel() submodel2 = MyModel() - assert submodel1.output.metadata.edge_type is Tensor - assert submodel2.input.metadata.edge_type is Tensor + assert get_origin(submodel1.output.metadata.edge_type) is Tensor + assert get_origin(submodel2.input.metadata.edge_type) is Tensor assert submodel1.output.metadata.shape is not None assert submodel2.input.metadata.shape is not None ref_var1 = next(iter(submodel1.output.metadata.shape.reprs)) @@ -1483,8 +1484,8 @@ def __init__(self) -> None: submodel1 = MyModel() submodel2 = MyModel() - assert submodel1.output.metadata.edge_type is Tensor - assert submodel2.input.metadata.edge_type is Tensor + assert get_origin(submodel1.output.metadata.edge_type) is Tensor + assert get_origin(submodel2.input.metadata.edge_type) is Tensor ref_var1 = submodel1.output.metadata.shape ref_var2 = submodel2.input.metadata.shape diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 07666901..3b34418b 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -19,6 +19,7 @@ import typing from copy import deepcopy from functools import partial +from typing import get_origin import jax import mlx.core as mx @@ -1942,7 +1943,9 @@ def test_regularization_5(): output=IOKey(name="output"), ) model += Multiply()( - left=IOKey("left1", type=Tensor), right="w", output=IOKey(name="output2") + left=IOKey("left1", type=Tensor), + right="w", + output=IOKey(name="output2"), ) ctx = TrainModel(model) @@ -2010,7 +2013,9 @@ def test_static_anlaysis_1(): right=IOKey(value=Tensor([2.0]), name="right"), ) model += Add()( - left=add1.output, right=IOKey(type=Tensor), output=IOKey(name="output1") + left=add1.output, + right=IOKey(type=Tensor), + output=IOKey(name="output1"), ) comp_model = mithril.compile( @@ -2037,7 +2042,9 @@ def test_static_anlaysis_2(): ) model += sum1(input=add1.output) model += Add()( - left=sum1.output, right=IOKey(type=Tensor), output=IOKey(name="output1") + left=sum1.output, + right=IOKey(type=Tensor), + output=IOKey(name="output1"), ) comp_model = mithril.compile( @@ -2185,7 +2192,10 @@ def test_prune_4(): add2 = Add() add3 = Add() - m += add0(left=IOKey("input", type=Tensor), right=IOKey("input2", type=Tensor)) + m += add0( + left=IOKey("input", type=Tensor), + right=IOKey("input2", type=Tensor), + ) m += add1(left="input", right="input2") # Duplicate m += add2(left=add0.output, right=add0.output) m += add3(left=add1.output, right=add1.output) # Duplicate @@ -2217,7 +2227,10 @@ def test_prune_5(): add2 = Add() add3 = Add() add4 = Add() - m += add0(left=IOKey("input", type=Tensor), right=IOKey("input2", type=Tensor)) + m += add0( + left=IOKey("input", type=Tensor), + right=IOKey("input2", type=Tensor), + ) m += add1(left="input", right="input2") # Duplicate m += add2(left=add0.output, right=add1.output) m += Add()(left=add1.output, right=add0.output) @@ -2246,13 +2259,17 @@ def test_prune_5(): def test_prune_6(): m1 = Model() add0 = Add() - m1 += add0(left=IOKey("input", type=Tensor), right=IOKey("input2", type=Tensor)) + m1 += add0( + left=IOKey("input", type=Tensor), + right=IOKey("input2", type=Tensor), + ) m1 += Add()(left=add0.output, right=add0.output, output=IOKey(name="output")) m2 = Model() add0 = Add() m2 += add0( - left=IOKey("input", type=Tensor), right=IOKey("input2", type=Tensor) + left=IOKey("input", type=Tensor), + right=IOKey("input2", type=Tensor), ) # Duplicate m2 += Multiply()(left=add0.output, right=add0.output, output=IOKey(name="output")) @@ -2520,7 +2537,9 @@ def test_prune_valued_tensor_1(): # Values different do not prune! model = Model() model += Add()( - left=Tensor(5), right=IOKey("input2", type=Tensor), output=IOKey("output1") + left=Tensor(5), + right=IOKey("input2", type=Tensor), + output=IOKey("output1"), ) model += Add()(left=Tensor(3), right="input2", output=IOKey("output2")) @@ -2541,7 +2560,9 @@ def test_prune_valued_tensor_2(): # Values same prune! model = Model() model += Add()( - left=Tensor(3), right=IOKey("input2", type=Tensor), output=IOKey("output1") + left=Tensor(3), + right=IOKey("input2", type=Tensor), + output=IOKey("output1"), ) model += Add()(left=Tensor(3), right="input2", output=IOKey("output2")) @@ -2568,7 +2589,9 @@ def test_prune_valued_tensor_3(): output=IOKey("output1"), ) model += Add()( - left=IOKey("left2", type=Tensor), right="input2", output=IOKey("output2") + left=IOKey("left2", type=Tensor), + right="input2", + output=IOKey("output2"), ) backend = JaxBackend(dtype=mithril.float64) @@ -2599,7 +2622,9 @@ def test_prune_valued_tensor_4(): output=IOKey("output1"), ) model += Add()( - left=IOKey("left2", type=Tensor), right="input3", output=IOKey("output2") + left=IOKey("left2", type=Tensor), + right="input3", + output=IOKey("output2"), ) backend = JaxBackend(dtype=mithril.float64) @@ -5774,7 +5799,7 @@ def test_deepcopy_1(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: assert id(data.value) == id(copied_data.value) @@ -5802,7 +5827,7 @@ def test_deepcopy_2(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: assert id(data.value) == id(copied_data.value) @@ -5830,7 +5855,7 @@ def test_deepcopy_3(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: assert id(data.value) == id(copied_data.value) @@ -5855,7 +5880,7 @@ def test_deepcopy_4(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: assert id(data.value) == id(copied_data.value) @@ -5891,7 +5916,7 @@ def test_deepcopy_5(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if data.edge_type is Tensor: + if get_origin(data.edge_type) is Tensor: assert id(data.value) == id(copied_data.value) diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py index 14741227..ffe14582 100644 --- a/tests/scripts/test_shapes.py +++ b/tests/scripts/test_shapes.py @@ -16,6 +16,7 @@ from copy import deepcopy from itertools import combinations, product from types import EllipsisType, NoneType +from typing import get_origin import numpy as np import pytest @@ -769,7 +770,7 @@ def test_simple_composite_1_set_shapes(): def test_simple_composite_1_extend_inputs(): model = Model() mult = Multiply() - right_input = Tensor(np.random.randn(2, 2).tolist()) + right_input: Tensor[float] = Tensor(np.random.randn(2, 2).tolist()) model += mult( left=IOKey(value=Tensor([[2.0]]), name="left"), right=IOKey(value=right_input, name="right"), @@ -914,7 +915,7 @@ def test_simple_composite_2_set_shapes_2(): def test_simple_composite_2_extend_inputs(): model = Model() mult = Multiply() - Multiply_0_right = Tensor(np.random.randn(2, 2).tolist()) + Multiply_0_right: Tensor[float] = Tensor(np.random.randn(2, 2).tolist()) model += mult( left=IOKey(value=Tensor(2.0), name="left"), right=IOKey(value=Multiply_0_right, name="in1"), @@ -948,7 +949,8 @@ def test_simple_composite_2_static_shapes(): model = Model() mult = Multiply() model += mult( - left=IOKey(value=Tensor(2.0), name="left"), right=IOKey("in1", type=Tensor) + left=IOKey(value=Tensor(2.0), name="left"), + right=IOKey("in1", type=Tensor), ) model += Divide()( numerator=IOKey(value=Tensor(2.0), name="numerator"), @@ -979,7 +981,8 @@ def test_simple_composite_2_static_inputs(): model = Model() mult = Multiply() model += mult( - left=IOKey(value=Tensor(2.0), name="left"), right=IOKey("in1", type=Tensor) + left=IOKey(value=Tensor(2.0), name="left"), + right=IOKey("in1", type=Tensor), ) model += Divide()( numerator=IOKey(value=Tensor(2.0), name="numerator"), @@ -1342,7 +1345,8 @@ def get_composite_1(): # Create common composite_1 model for corresponding tests. composite_1 = Model() composite_1 += (m1 := Multiply())( - left=IOKey("input1", type=Tensor), right=IOKey("input2", type=Tensor) + left=IOKey("input1", type=Tensor), + right=IOKey("input2", type=Tensor), ) composite_1 += (m2 := Multiply())(left="input2", right=m1.output) composite_1 += Add()(left=m2.output, right=m2.output, output=IOKey(name="output")) @@ -1373,8 +1377,10 @@ def test_composite_1_static_shapes_1(): def test_composite_1_extend_inputs_1(): composite = Model() m1 = Multiply() - Multiply_0_left = Tensor(np.random.randn(1, 1, 1, 1, 1, 1, 1, 37, 43).tolist()) - Multiply_0_right = Tensor(np.random.randn(134, 47, 1, 1, 1).tolist()) + Multiply_0_left: Tensor[float] = Tensor( + np.random.randn(1, 1, 1, 1, 1, 1, 1, 37, 43).tolist() + ) + Multiply_0_right: Tensor[float] = Tensor(np.random.randn(134, 47, 1, 1, 1).tolist()) composite += m1( left=IOKey(value=Multiply_0_left, name="left"), right=IOKey(value=Multiply_0_right, name="right"), @@ -1659,7 +1665,8 @@ def test_composite_2_set_shapes_3(): m2 += Add()(left=mult4.output, right=mult4.output, output=IOKey(name="output")) m3 += (add1 := Add())( - left=IOKey("input1", type=Tensor), right=IOKey("input2", type=Tensor) + left=IOKey("input1", type=Tensor), + right=IOKey("input2", type=Tensor), ) m3 += (mult5 := Multiply())(left="input2", right=add1.output) m3 += Add()(left=mult5.output, right=mult5.output, output=IOKey(name="output")) @@ -2273,8 +2280,8 @@ def test_composite_3_extend_shapes_1(): composite_3 = Model() m1 = Model() add1 = Add() - add_1_left = Tensor(np.random.randn(3, 4, 5, 6, 1).tolist()) - add_1_right = Tensor(np.random.randn(1, 1, 1, 1, 7).tolist()) + add_1_left: Tensor[float] = Tensor(np.random.randn(3, 4, 5, 6, 1).tolist()) + add_1_right: Tensor[float] = Tensor(np.random.randn(1, 1, 1, 1, 7).tolist()) m1 += add1( left=IOKey(value=add_1_left, name="left"), right=IOKey(value=add_1_right, name="right"), @@ -4124,7 +4131,10 @@ def __init__(self) -> None: input1=BaseKey(shape=["u1", "u2", "u3", ("Var1", ...)], type=Tensor), input2=BaseKey(shape=["u4", "u5", ("Var2", ...), "u6"], type=Tensor), input3=BaseKey(shape=["u7", ("Var3", ...), "u8", "u9"], type=Tensor), - input4=BaseKey(shape=[("Var4", ...), "u10", "u11", "u12"], type=Tensor), + input4=BaseKey( + shape=[("Var4", ...), "u10", "u11", "u12"], + type=Tensor, + ), input5=BaseKey( shape=[("Var5", ...), "u13", "u14", "u15", "u16"], type=Tensor, @@ -4205,8 +4215,14 @@ def __init__(self) -> None: input1=BaseKey(shape=["u1", "u2", ("Var1", ...)], type=Tensor), input2=BaseKey(shape=["u3", ("Var2", ...), "u4"], type=Tensor), input3=BaseKey(shape=[("Var3", ...), "u5", "u6"], type=Tensor), - input4=BaseKey(shape=["u7", "u8", ("Var4", ...), "u9", "u10"], type=Tensor), - input5=BaseKey(shape=["u11", ("Var4", ...), "u12", "u13"], type=Tensor), + input4=BaseKey( + shape=["u7", "u8", ("Var4", ...), "u9", "u10"], + type=Tensor, + ), + input5=BaseKey( + shape=["u11", ("Var4", ...), "u12", "u13"], + type=Tensor, + ), output=BaseKey(shape=["u5", "u5"], type=Tensor), ) @@ -4421,7 +4437,7 @@ def test_total_repr_count(): edge = var2.input.data.metadata - assert edge.edge_type is Tensor + assert get_origin(edge.edge_type) is Tensor assert edge.shape is not None assert len(edge.shape.reprs) == 2 @@ -4429,7 +4445,7 @@ def test_total_repr_count(): def test_total_repr_count_linear_1(): model = Linear() edge = model.input.metadata - assert edge.edge_type is Tensor + assert get_origin(edge.edge_type) is Tensor assert edge.shape is not None shp_repr = next(iter(edge.shape.reprs)) @@ -5120,10 +5136,13 @@ def test_variadic_naming_12(): def test_variadic_naming_13(): model = Model() model += (mult := MatrixMultiply())( - left=IOKey("input", type=Tensor), right=IOKey("w", type=Tensor) + left=IOKey("input", type=Tensor), + right=IOKey("w", type=Tensor), ) model += Add()( - left=mult.output, right=IOKey("b", type=Tensor), output=IOKey(name="output") + left=mult.output, + right=IOKey("b", type=Tensor), + output=IOKey(name="output"), ) shapes: dict[str, list] = { @@ -5540,8 +5559,14 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=BaseKey(shape=[("Var1", ...), "a", "b", "c"], type=Tensor), - output=BaseKey(shape=["c", ("Var1", ...), "a", "b"], type=Tensor), + input=BaseKey( + shape=[("Var1", ...), "a", "b", "c"], + type=Tensor, + ), + output=BaseKey( + shape=["c", ("Var1", ...), "a", "b"], + type=Tensor, + ), ) def __call__( # type: ignore[override] @@ -5676,7 +5701,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=BaseKey(shape=[("Var1", ...), "u1", "u2", "u3"], type=Tensor), + input=BaseKey( + shape=[("Var1", ...), "u1", "u2", "u3"], + type=Tensor, + ), output=BaseKey(shape=[("Var1", ...), "u4"], type=Tensor), ) @@ -5696,12 +5724,12 @@ def __call__( # type: ignore[override] model.set_shapes(shape_1) model.set_shapes(shape_2) in_data = model.input.metadata - assert in_data.edge_type is Tensor + assert get_origin(in_data.edge_type) is Tensor assert (node := in_data.shape) is not None input_repr = next(iter(node.reprs)) out_data = model.output.metadata - assert out_data.edge_type is Tensor + assert get_origin(out_data.edge_type) is Tensor assert out_data.shape is not None output_repr = next(iter(out_data.shape.reprs)) @@ -5716,7 +5744,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=BaseKey(shape=[("Var1", ...), "u1", "u2", "u3"], type=Tensor), + input=BaseKey( + shape=[("Var1", ...), "u1", "u2", "u3"], + type=Tensor, + ), output=BaseKey(shape=[("Var1", ...), "u4"], type=Tensor), ) @@ -5737,10 +5768,10 @@ def __call__( # type: ignore[override] model.set_shapes(shape_1) model.set_shapes(shape_2) - assert model.input.metadata.edge_type is Tensor + assert get_origin(model.input.metadata.edge_type) is Tensor assert (in_node := model.input.metadata.shape) is not None input_repr = next(iter(in_node.reprs)) - assert model.output.metadata.edge_type is Tensor + assert get_origin(model.output.metadata.edge_type) is Tensor assert (out_node := model.output.metadata.shape) is not None output_repr = next(iter(out_node.reprs)) @@ -5825,7 +5856,7 @@ def test_same_uniadic_5(): buffer.set_shapes(shape_1) buffer.set_shapes(shape_2) - assert buffer.input.metadata.edge_type is Tensor + assert get_origin(buffer.input.metadata.edge_type) is Tensor assert buffer.input.metadata.shape is not None input_reprs = buffer.input.metadata.shape.reprs @@ -6695,7 +6726,7 @@ def find_all_reprs(repr: ShapeRepr, repr_cache=None) -> set[ShapeRepr]: all_tensor_conns = { con for con in model.conns.all.values() - if con.metadata.edge_type is Tensor + if get_origin(con.metadata.edge_type) is Tensor } # Find all reprs that are linked to shape reprs of the tensors @@ -6833,18 +6864,18 @@ def test_node_count_1(): # Check total existing node count all_nodes = set() for con in model.conns.all.values(): - assert con.metadata.edge_type is Tensor + assert get_origin(con.metadata.edge_type) is Tensor all_nodes.add(con.metadata.shape) assert len(all_nodes) == 1 # Check total variadics repr count - assert sub_model.input.metadata.edge_type is Tensor + assert get_origin(sub_model.input.metadata.edge_type) is Tensor assert (in_node := sub_model.input.metadata.shape) is not None assert (in_repr := next(iter(in_node.reprs))) is not None assert in_repr.root is not None assert len(in_repr.root.reprs) == 1 - assert sub_model.output.metadata.edge_type is Tensor + assert get_origin(sub_model.output.metadata.edge_type) is Tensor assert (out_node := sub_model.output.metadata.shape) is not None assert (out_repr := next(iter(out_node.reprs))) is not None assert out_repr.root is not None @@ -6864,7 +6895,7 @@ def test_node_count_2(): all_nodes = set() for con in model.conns.all.values(): edge = con.metadata - assert edge.edge_type is Tensor + assert get_origin(edge.edge_type) is Tensor all_nodes.add(edge.shape) assert len(all_nodes) == 2 @@ -6886,7 +6917,7 @@ def test_node_count_3(): shapes = set() for con in model.conns.all.values(): edge = con.metadata - assert edge.edge_type is Tensor + assert get_origin(edge.edge_type) is Tensor shapes.add(edge.shape) assert len(shapes) == 3 @@ -6919,7 +6950,10 @@ def __init__(self) -> None: super().__init__( formula_key="buffer", input=BaseKey(shape=["a", ("Var1", ...)], type=Tensor), - output=BaseKey(shape=[("Var1", ...), "c", "d", "e"], type=Tensor), + output=BaseKey( + shape=[("Var1", ...), "c", "d", "e"], + type=Tensor, + ), ) model = Model() @@ -6941,7 +6975,10 @@ class MyModel(PrimitiveModel): def __init__(self) -> None: super().__init__( formula_key="buffer", - input=BaseKey(shape=[("Var1", ...), "c", "d", "e"], type=Tensor), + input=BaseKey( + shape=[("Var1", ...), "c", "d", "e"], + type=Tensor, + ), output=BaseKey(shape=["a", "b", ("Var1", ...)], type=Tensor), ) @@ -7419,7 +7456,7 @@ def __call__( # type: ignore[override] model += Buffer() model += test_model all_nodes = get_all_nodes(model) - assert buff_model.input.metadata.edge_type is Tensor + assert get_origin(buff_model.input.metadata.edge_type) is Tensor ref_all_nodes = { test_model.output.metadata.shape, # type: ignore buff_model.input.metadata.shape, @@ -7459,7 +7496,7 @@ def __call__( # type: ignore[override] all_nodes = get_all_nodes(model) data = buff_model.input.metadata - assert data.edge_type is Tensor + assert get_origin(data.edge_type) is Tensor ref_all_nodes = {data.shape} assert all_nodes == ref_all_nodes @@ -7475,7 +7512,7 @@ def test_node_count_6(): all_nodes = get_all_nodes(model) data = buff_model.input.metadata - assert data.edge_type is Tensor + assert get_origin(data.edge_type) is Tensor ref_all_nodes = {data.shape} assert all_nodes == ref_all_nodes @@ -7490,7 +7527,7 @@ def test_node_count_7(): for _ in range(5): model += deepcopy(model) all_nodes = get_all_nodes(model) - assert buff_model.input.metadata.edge_type is Tensor + assert get_origin(buff_model.input.metadata.edge_type) is Tensor ref_all_nodes = {buff_model.input.metadata.shape} assert all_nodes == ref_all_nodes @@ -7533,13 +7570,16 @@ def test_node_count_10(): buff_model3 = Buffer() submodel1 += buff_model1( - input=IOKey("input1", type=Tensor), output=IOKey(name="output1") + input=IOKey("input1", type=Tensor), + output=IOKey(name="output1"), ) submodel1 += buff_model2( - input=IOKey("input2", type=Tensor), output=IOKey(name="output2") + input=IOKey("input2", type=Tensor), + output=IOKey(name="output2"), ) submodel1 += buff_model3( - input=IOKey("input3", type=Tensor), output=IOKey(name="output3") + input=IOKey("input3", type=Tensor), + output=IOKey(name="output3"), ) model = Model() @@ -7681,7 +7721,7 @@ def __call__( # type: ignore[override] all_nodes = get_all_nodes(model) data = test_model.input.metadata - assert data.edge_type is Tensor + assert get_origin(data.edge_type) is Tensor ref_all_nodes = {data.shape} assert all_nodes == ref_all_nodes @@ -7710,7 +7750,7 @@ def __call__( # type: ignore[override] model += MyModel() all_nodes = get_all_nodes(model) - assert test_model.input.metadata.edge_type is Tensor + assert get_origin(test_model.input.metadata.edge_type) is Tensor ref_all_nodes = {test_model.input.metadata.shape} assert all_nodes == ref_all_nodes @@ -7740,9 +7780,9 @@ def __call__( # type: ignore[override] model += MyModel() all_nodes = get_all_nodes(model) - assert test_model1.input.metadata.edge_type is Tensor - assert test_model2.input.metadata.edge_type is Tensor - assert test_model3.input.metadata.edge_type is Tensor + assert get_origin(test_model1.input.metadata.edge_type) is Tensor + assert get_origin(test_model2.input.metadata.edge_type) is Tensor + assert get_origin(test_model3.input.metadata.edge_type) is Tensor ref_all_nodes = { test_model1.input.metadata.shape, test_model2.input.metadata.shape, @@ -7776,9 +7816,9 @@ def __call__( # type: ignore[override] model += MyModel() all_nodes = get_all_nodes(model) - assert test_model1.input.metadata.edge_type is Tensor - assert test_model2.input.metadata.edge_type is Tensor - assert test_model3.input.metadata.edge_type is Tensor + assert get_origin(test_model1.input.metadata.edge_type) is Tensor + assert get_origin(test_model2.input.metadata.edge_type) is Tensor + assert get_origin(test_model3.input.metadata.edge_type) is Tensor ref_all_nodes = { test_model1.input.metadata.shape, test_model2.input.metadata.shape, @@ -7812,9 +7852,9 @@ def __call__( # type: ignore[override] model += MyModel() all_nodes = get_all_nodes(model) - assert test_model1.input.metadata.edge_type is Tensor - assert test_model2.input.metadata.edge_type is Tensor - assert test_model3.input.metadata.edge_type is Tensor + assert get_origin(test_model1.input.metadata.edge_type) is Tensor + assert get_origin(test_model2.input.metadata.edge_type) is Tensor + assert get_origin(test_model3.input.metadata.edge_type) is Tensor ref_all_nodes = { test_model1.input.metadata.shape, test_model2.input.metadata.shape, @@ -7853,7 +7893,7 @@ def test_uniadic_repr_count_2(): } model.set_shapes(shape_1) - assert buff_model1.input.metadata.edge_type is Tensor + assert get_origin(buff_model1.input.metadata.edge_type) is Tensor data_shape = buff_model1.input.metadata.shape assert data_shape is not None @@ -7888,7 +7928,7 @@ def test_uniadic_repr_count_3(): } ) - assert buff_model1.input.metadata.edge_type is Tensor + assert get_origin(buff_model1.input.metadata.edge_type) is Tensor data_shape = buff_model1.input.metadata.shape assert data_shape is not None @@ -7997,7 +8037,7 @@ def test_uniadic_repr_count_5(): model.set_shapes(shapes) - assert buff_model1.input.metadata.edge_type is Tensor + assert get_origin(buff_model1.input.metadata.edge_type) is Tensor data_shape = buff_model1.input.metadata.shape assert data_shape is not None @@ -8123,8 +8163,8 @@ def test_add_model_with_scalar_input(): model = Model() add1 = Add() - left_input = Tensor(np.ones((3, 4, 5, 6, 7)).tolist()) - right_input = Tensor(np.ones((3, 4, 5, 6, 7)).tolist()) + left_input: Tensor[float] = Tensor(np.ones((3, 4, 5, 6, 7)).tolist()) + right_input: Tensor[float] = Tensor(np.ones((3, 4, 5, 6, 7)).tolist()) model += add1(left=left_input, right=right_input, output=IOKey(name="output")) assert_all_nodes_unique(model) @@ -8192,7 +8232,7 @@ def test_possible_uniadic_values_directed_8(): buff_model = Buffer() buff_model.set_shapes({"input": ["a", "b"]}) - assert buff_model.input.metadata.edge_type is Tensor + assert get_origin(buff_model.input.metadata.edge_type) is Tensor data_shape = buff_model.input.metadata.shape assert data_shape is not None @@ -8212,7 +8252,7 @@ def test_possible_uniadic_values_directed_9(): buff_model = Buffer() buff_model.set_shapes({"input": ["a", "b", "c", "d"]}) - assert buff_model.input.metadata.edge_type is Tensor + assert get_origin(buff_model.input.metadata.edge_type) is Tensor data_shape = buff_model.input.metadata.shape assert data_shape is not None @@ -9676,7 +9716,7 @@ def test_remove_variadic(): with pytest.raises(Exception) as err_info: # model.shape_map["output"].remove_variadic([Uniadic(5)]) data = model.conns.get_data("output") - assert data.edge_type is Tensor + assert get_origin(data.edge_type) is Tensor data_shape = data.shape assert data_shape is not None next(iter(data_shape.reprs)).remove_variadic([Uniadic(5)]) @@ -9693,7 +9733,7 @@ def test_bcast_left(): } model.set_shapes(shape_1) - assert model.output.metadata.edge_type is Tensor + assert get_origin(model.output.metadata.edge_type) is Tensor data_shape = model.output.metadata.shape assert data_shape is not None assert data_shape.get_shapes() == [2, "u1", "(V1, ...)"] @@ -9720,7 +9760,7 @@ def test_bcast_left_2(): } model.set_shapes(shape_1) - assert model.output.metadata.edge_type is Tensor + assert get_origin(model.output.metadata.edge_type) is Tensor data_shape = model.output.metadata.shape assert data_shape is not None @@ -9740,7 +9780,7 @@ def test_bcast_left_3(): } ) - assert model.output.metadata.edge_type is Tensor + assert get_origin(model.output.metadata.edge_type) is Tensor data_shape = model.output.metadata.shape assert data_shape is not None diff --git a/tests/scripts/test_type_coercion.py b/tests/scripts/test_type_coercion.py index b8d1bdf8..16790b32 100644 --- a/tests/scripts/test_type_coercion.py +++ b/tests/scripts/test_type_coercion.py @@ -164,7 +164,9 @@ def test_scalar_to_tensor_3(): model += shp_1(input=add_1.output) model += tensor_2(input=shp_1.output) model += Add()( - left=IOKey("left", type=Tensor), right=tensor_2.output, output="output" + left=IOKey("left", type=Tensor), + right=tensor_2.output, + output="output", ) model_1 = model @@ -174,11 +176,14 @@ def test_scalar_to_tensor_3(): add_2 = Add() shp_2 = Shape() model += add_2( - left=IOKey(value=[[[1]]]).tensor(), right=IOKey("right", type=Tensor) + left=IOKey(value=[[[1]]]).tensor(), + right=IOKey("right", type=Tensor), ) model += shp_2(input=add_2.output) model += Add()( - left=IOKey("left", type=Tensor), right=shp_2.output.tensor(), output="output" + left=IOKey("left", type=Tensor), + right=shp_2.output.tensor(), + output="output", ) model_2 = model @@ -846,18 +851,9 @@ def test_type_propagation_floor_divide_4(): input=floor_div.output, output=IOKey(name="output") ) - assert ( - str(error_info.value) - == ( - "Acceptable types are int | float, but type value is " - "provided!" - ) - ) or ( - str(error_info.value) - == ( - "Acceptable types are float | int, but type value is " - "provided!" - ) + assert str(error_info.value) == ( + "Acceptable types are mithril.framework.common.Tensor[int | float], " + "but mithril.framework.common.Tensor[bool] type is provided!" ) @@ -1509,7 +1505,8 @@ def test_coercion_3(): reduce_model = Sum(axis=TBD) add_model = Add() model += add_model( - left=IOKey("left", type=Tensor), right=IOKey(value=[0, 1]).tensor() + left=IOKey("left", type=Tensor), + right=IOKey(value=[0, 1]).tensor(), ) model += (to_list := TensorToList())(input=add_model.output) model += reduce_model(input="input", axis=to_list.output, output="output") @@ -1534,7 +1531,8 @@ def test_coercion_4(): reduce_model = Sum(axis=TBD) add_model = Add() model += add_model( - left=IOKey("left", type=Tensor), right=IOKey(value=[0, 1]).tensor() + left=IOKey("left", type=Tensor), + right=IOKey(value=[0, 1]).tensor(), ) model += (to_list := TensorToList())(input=add_model.output) model += reduce_model(input="input", axis=to_list.output, output="output") @@ -1640,9 +1638,3 @@ def test_tensor_to_scalar_template_2(): assert_results_equal(outputs, ref_outputs) assert_results_equal(grads, ref_grads) - - -# def test_find_intersection_type_nested_list_type(): -# type1 = int | float | list | tuple -# type2 = NestedListType(int | float) -# assert find_intersection_type(type1, type2) == type2 diff --git a/tests/scripts/test_type_consistencies.py b/tests/scripts/test_type_consistencies.py index 5a290331..3d49e3a0 100644 --- a/tests/scripts/test_type_consistencies.py +++ b/tests/scripts/test_type_consistencies.py @@ -21,9 +21,14 @@ import torch import mithril -from mithril.framework.common import NOT_GIVEN, BaseKey, ConnectionType -from mithril.framework.utils import ( +from mithril.framework.common import ( + NOT_GIVEN, + BaseKey, + ConnectionType, + ToBeDetermined, find_intersection_type, +) +from mithril.framework.utils import ( find_type, infer_all_possible_types, sort_type, @@ -144,7 +149,7 @@ def test_default_given_extend_4_numpy_error(): assert str(err_info.value) == ( "Acceptable types are , " - "but type value is provided!" + "but type is provided!" ) @@ -172,7 +177,7 @@ def test_constant_backendvar_numpy(): model += other_model(input=model.mult_out, axis=model.axis) # type: ignore assert str(err_info.value) == ( "Acceptable types are , " - "but type value is provided!" + "but type is provided!" ) @@ -241,7 +246,7 @@ def test_type_6(): model += test_model_2(input1=test_model_1.output) # type: ignore assert str(err_info.value) == ( "Acceptable types are tuple[tuple[int, ...]], but tuple[int, ...] type " - "value is provided!" + "is provided!" ) @@ -259,7 +264,7 @@ def test_type_7(): model += test_model_3(input1="", input3="input1") assert ( str(err_info.value) - == "Acceptable types are , but float | str type value is provided!" + == "Acceptable types are , but float | str type is provided!" ) @@ -277,7 +282,7 @@ def test_type_8(): model += model2(input1="input1") assert str(err_info.value) == ( "Acceptable types are tuple[int, int, int, int], but float | int type " - "value is provided!" + "is provided!" ) @@ -398,8 +403,7 @@ def test_type_16(): output=IOKey(name="output2"), ) assert str(err_info.value) == ( - "Acceptable types are , but type value " - "is provided!" + "Acceptable types are , but type " "is provided!" ) @@ -421,8 +425,7 @@ def test_type_17(): output=IOKey(name="output2"), ) assert str(err_info.value) == ( - "Acceptable types are , " - "but type value is provided!" + "Acceptable types are , " "but type is provided!" ) @@ -813,6 +816,48 @@ def test_find_intersection_types_34(): assert find_intersection_type(type_1, type_2) is None +def test_find_intersection_types_35(): + type_1 = Tensor[int] | int + type_2 = Tensor[int] + assert find_intersection_type(type_1, type_2) is Tensor[int] + + +def test_find_intersection_types_36(): + type_1 = Tensor[int] | int + type_2 = Tensor[int] | int | float + assert find_intersection_type(type_1, type_2) == Tensor[int] | int + + +def test_find_intersection_types_37(): + type_1 = Tensor[int] | int + type_2 = ToBeDetermined + assert find_intersection_type(type_1, type_2) == Tensor[int] | int + + +def test_find_intersection_types_38(): + type_1 = Tensor[int] | int | Tensor[float] + type_2 = ToBeDetermined + assert find_intersection_type(type_1, type_2) == Tensor[int] | int | Tensor[float] + + +def test_find_intersection_types_39(): + type_1 = Tensor[int] | int | Tensor[int | float] + type_2 = Tensor[int] | int | float + assert find_intersection_type(type_1, type_2) == Tensor[int] | int + + +def test_find_intersection_types_40(): + type_1 = Tensor[int] | int | Tensor[int | float] + type_2 = Tensor[int | float] | int | float + assert find_intersection_type(type_1, type_2) == Tensor[int | float] | int + + +def test_find_intersection_types_41(): + type_1 = Tensor[int] | Tensor[int | float] + type_2 = Tensor[int | float] + assert find_intersection_type(type_1, type_2) == Tensor[int | float] + + def test_find_type_1(): input = (3, 4) typ = find_type(input) diff --git a/tests/scripts/test_utils.py b/tests/scripts/test_utils.py index 275b8170..305c0051 100644 --- a/tests/scripts/test_utils.py +++ b/tests/scripts/test_utils.py @@ -15,7 +15,7 @@ import inspect import re from collections.abc import Callable, Mapping, Sequence -from typing import Any +from typing import Any, get_origin import numpy as np @@ -26,10 +26,11 @@ ShapeTemplateType, Tensor, Uniadic, + find_intersection_type, ) from mithril.framework.logical import BaseModel, Model, PrimitiveModel from mithril.framework.physical import PhysicalModel -from mithril.framework.utils import find_intersection_type, find_type +from mithril.framework.utils import find_type from mithril.models.train_model import TrainModel from mithril.utils.dict_conversions import dict_to_model, model_dict from mithril.utils.type_utils import is_list_int @@ -296,7 +297,7 @@ def get_all_nodes(model: BaseModel): node_set = { data.shape for data in all_data - if data.edge_type is Tensor + if get_origin(data.edge_type) is Tensor if data.shape is not None } return node_set From 10bbbcead0bd04d8308ae4640128ed847023bb73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Naz=C4=B1m=20=C3=96nder=20Orhan?= Date: Tue, 28 Jan 2025 15:35:30 +0300 Subject: [PATCH 2/2] Updated upon review feedbacks. --- mithril/framework/codegen/numpy_gen.py | 5 +- mithril/framework/codegen/python_gen.py | 5 +- mithril/framework/common.py | 30 +- mithril/framework/constraints.py | 28 +- mithril/framework/logical/model.py | 4 +- mithril/framework/logical/primitive.py | 14 +- mithril/framework/physical/data_store.py | 8 +- mithril/framework/physical/model.py | 13 +- mithril/models/models.py | 274 +++++++++--------- mithril/utils/dict_conversions.py | 4 +- mithril/utils/func_utils.py | 5 +- tests/scripts/helper.py | 5 +- tests/scripts/test_constr_counter.py | 10 +- tests/scripts/test_constraints.py | 6 +- tests/scripts/test_hyperedge.py | 25 +- tests/scripts/test_io_key.py | 3 +- .../test_randomized_models_all_backends.py | 6 +- tests/scripts/test_ref_counts.py | 27 +- tests/scripts/test_scripts.py | 11 +- tests/scripts/test_shapes.py | 77 +++-- tests/scripts/test_utils.py | 7 +- 21 files changed, 273 insertions(+), 294 deletions(-) diff --git a/mithril/framework/codegen/numpy_gen.py b/mithril/framework/codegen/numpy_gen.py index a95e2303..b4b41b57 100644 --- a/mithril/framework/codegen/numpy_gen.py +++ b/mithril/framework/codegen/numpy_gen.py @@ -16,7 +16,7 @@ import keyword from collections.abc import Callable from functools import partial -from typing import Any, Literal, get_origin, overload +from typing import Any, Literal, overload import numpy as np @@ -32,7 +32,6 @@ IOHyperEdge, LossKey, ParamsEvalType, - Tensor, find_intersection_type, is_type_adjustment_required, ) @@ -318,7 +317,7 @@ def generate_evaluate_gradients( key for key in all_ignored_keys if key in self.pm.data - and get_origin(self.pm.data[key].edge_type) is Tensor + and self.pm.data[key].is_tensor and find_intersection_type(self.pm.data[key].value_type, float) } diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index 7dde24fe..35d3e664 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -19,7 +19,7 @@ from collections.abc import Callable from functools import partial from posixpath import basename, splitext -from typing import Any, Generic, Literal, Protocol, get_origin, overload +from typing import Any, Generic, Literal, Protocol, overload from ...backends.backend import ParallelBackend from ...core import DataType, Dtype @@ -30,7 +30,6 @@ EvaluateGradientsType, EvaluateType, ParamsEvalType, - Tensor, ) from ..logical import PrimitiveModel from ..physical.model import PhysicalModel @@ -301,7 +300,7 @@ def generate_imports(self) -> list[ast.stmt]: def is_static_scalar(self, key: str) -> bool: return ( key in self.pm.data_store.cached_data - and get_origin(self.pm.data[key].edge_type) != Tensor + and not self.pm.data[key].is_tensor and self.pm.data[key].edge_type != Dtype and not isinstance(self.pm.data_store.cached_data[key], enum.Enum) ) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index b971baf1..49ed841a 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -471,7 +471,7 @@ def _delete_node(remaining: ShapeNode, deleted: ShapeNode) -> Updates: updates = remaining.merge(deleted) # Iterate over deleted nodes referees to remove deleted node. for ref in deleted.referees: - if get_origin(ref.edge_type) is not Tensor: + if not ref.is_tensor: raise ValueError("Non-tensor edges cannot have any shape.") assert isinstance(ref._value, Tensor) ref._value.shape = remaining @@ -607,7 +607,7 @@ def get_shapes( shapes: dict[str, ShapeTemplateType | list[ShapeTemplateType] | None] = {} for key, data in data_dict.items(): key_name = key_mappings.get(key, key) - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: assert data.shape is not None shapes[key_name] = data.shape.get_shapes( uniadic_keys, varadic_keys, symbolic, verbose @@ -695,6 +695,12 @@ def find_intersection_type( type_1: type | UnionType | GenericAlias | type[Tensor[int | float | bool]], type_2: type | UnionType | GenericAlias | type[Tensor[int | float | bool]], ) -> type | UnionType | GenericAlias | type[Tensor[int | float | bool]] | None: + # If non-generic Tensor type is provided, convert it to generic Tensor type. + if type_1 is Tensor: + type_1 = Tensor[int | float | bool] + if type_2 is Tensor: + type_2 = Tensor[int | float | bool] + # ToBeDetermined type can be coerced to all types. if type_1 is ToBeDetermined: return type_2 @@ -977,6 +983,10 @@ def is_polymorphic(self) -> bool: scalar_possible = find_intersection_type(ScalarValueType, self._type) return None not in (tensor_possible, scalar_possible) + @property + def is_tensor(self) -> bool: + return get_origin(self._type) is Tensor + @property def is_non_diff(self) -> bool: return not self.differentiable @@ -1081,7 +1091,7 @@ def set_value( self._type is ToBeDetermined or tensor_possible ): raise ValueError("Can not set Tensor value to a Scalar edge.") - if not isinstance(value, Tensor) and get_origin(self._type) is Tensor: + if not isinstance(value, Tensor) and self.is_tensor: raise ValueError("Can not set Scalar value to a Tensor edge.") # If any value different than self._value is provided, raise error. if not self._value_compatible(value): @@ -1475,11 +1485,6 @@ class BaseKey: type: UnionType | type | type[Tensor[int | float | bool]] | ScalarType | None = None interval: list[float | int] | None = None - def __post_init__(self) -> None: - # Convert to generic Tensor type if Tensor type is provided. - if self.type is Tensor: - self.type = Tensor[int | float | bool] - class IOKey(TemplateBase): def __init__( @@ -1586,7 +1591,7 @@ def __eq__(self, other: object) -> bool: def set_differentiable(self, differentiable: bool = True) -> None: # TODO: Move this method to Model class as set_shapes, set_types etc. - if get_origin(self.metadata.edge_type) is Tensor: + if self.metadata.is_tensor: self.metadata.differentiable = differentiable elif differentiable: if self.metadata.edge_type is not ToBeDetermined: @@ -1775,7 +1780,7 @@ def get_key_origin(self, key: str) -> str | None: def get_shape_node(self, key: str) -> ShapeNode: edge = self.get_metadata(key) - if get_origin(edge.edge_type) is not Tensor: + if not edge.is_tensor: raise ValueError("'Only Tensor type connections has shape!'") assert edge.shape is not None return edge.shape @@ -2769,8 +2774,7 @@ def get_most_informative_repr(self) -> ShapeRepr: ) ): most_informative_repr = repr - if most_informative_repr is None: - ... + assert most_informative_repr is not None return most_informative_repr @@ -3123,7 +3127,7 @@ def __call__(self, keys: list[IOHyperEdge]) -> ConstrainResultType: status = False updates = Updates() if self.type == UpdateType.SHAPE: - tensor_keys = [key for key in keys if get_origin(key.edge_type) is Tensor] + tensor_keys = [key for key in keys if key.is_tensor] for reprs in product(*[key.shape.reprs for key in tensor_keys]): # type: ignore for idx, repr in enumerate(reprs): tensor_keys[idx]._temp_shape = repr diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index ff2d9c03..dc54472f 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -179,7 +179,7 @@ def set_edge_type(edge: IOHyperEdge, new_type: Any) -> Updates: # Simply wraps new type into Tensor if edge_type is Tensor, # else sets directly. type = new_type - if get_origin(edge.edge_type) is Tensor: + if edge.is_tensor: type = Tensor[new_type] return edge.set_type(type) @@ -190,9 +190,9 @@ def edge_type_constraint( updates = Updates() status = False tensor_exists: bool = False - tensor_output: bool = get_origin(output.edge_type) is Tensor + tensor_output: bool = output.is_tensor for input in inputs: - if get_origin(input.edge_type) is Tensor: + if input.is_tensor: tensor_exists = True break # First set output edge_type to Tensor if any Tensor type inputs @@ -645,7 +645,7 @@ def indexer_initial_type_constraint( updates |= input.set_type(list[Any] | tuple[Any, ...]) status = True else: - tensor_edge = input if get_origin(input.edge_type) is Tensor else output + tensor_edge = input if input.is_tensor else output assert isinstance(tensor_edge._value, Tensor) typ: type[Tensor[int | float | bool]] = Tensor[tensor_edge.value_type] # type: ignore other_edge = (input, output)[tensor_edge is input] @@ -659,7 +659,7 @@ def indexer_type_constraint( ) -> ConstrainResultType: status = False updates = Updates() - if not (get_origin(input.edge_type) is Tensor or input.edge_type is ToBeDetermined): + if not (input.is_tensor or input.edge_type is ToBeDetermined): # Input is a non-tensor type. input_type = input.value_type output_type = output.value_type @@ -706,7 +706,7 @@ def indexer_type_constraint( updates |= output.set_type(inferred_out_type) status = not is_union(output.value_type) - elif get_origin(input.edge_type) is Tensor: + elif input.is_tensor: status = True return status, updates @@ -1415,9 +1415,9 @@ def bcast_helper( def _bcast( output: IOHyperEdge, left: IOHyperEdge, right: IOHyperEdge, index: int ) -> ConstrainResultType: - l_type = Tensor if get_origin(left.edge_type) is Tensor else left.edge_type - r_type = Tensor if get_origin(right.edge_type) is Tensor else right.edge_type - o_type = Tensor if get_origin(output.edge_type) is Tensor else output.edge_type + l_type = Tensor if left.is_tensor else left.edge_type + r_type = Tensor if right.is_tensor else right.edge_type + o_type = Tensor if output.is_tensor else output.edge_type if l_type is Tensor and r_type is Tensor: assert output._temp_shape is not None, "Output shape of broadcast is not set!" assert left._temp_shape is not None, "Left shape of broadcast is not set!" @@ -3687,7 +3687,7 @@ def tensor_item_constraint_helper( def indexer_constraints( output: IOHyperEdge, input: IOHyperEdge, index: IOHyperEdge ) -> ConstrainResultType: - if get_origin(input.edge_type) is Tensor: + if input.is_tensor: return tensor_item_constraints(output, input, index) elif input.edge_type is not ToBeDetermined: return scalar_item_constraints(output, input, index) @@ -3960,7 +3960,7 @@ def relational_operator_type_constraint( updates = Updates() status = False # Forward inference. - if Tensor in (get_origin(input1.edge_type), get_origin(input2.edge_type)): + if input1.is_tensor or input2.is_tensor: updates |= output.set_type(Tensor[bool]) status = True elif ToBeDetermined not in (input1.edge_type, input2.edge_type): @@ -3975,7 +3975,7 @@ def divide_type_constraint( updates = Updates() status = False # Forward inference. - if Tensor in (get_origin(numerator.edge_type), get_origin(denominator.edge_type)): + if numerator.is_tensor or denominator.is_tensor: updates |= output.set_type(Tensor[float]) status = True elif ToBeDetermined not in (numerator.edge_type, denominator.edge_type): @@ -3993,13 +3993,13 @@ def polynomial_kernel_constraint( # poly_coef update. if poly_coef.edge_type is not ToBeDetermined: coef_status = True - if get_origin(poly_coef.edge_type) is Tensor: + if poly_coef.is_tensor: assert poly_coef.shape is not None updates |= poly_coef.shape.set_values([]) # degree update. if degree.edge_type is not ToBeDetermined: degree_status = True - if get_origin(degree.edge_type) is Tensor: + if degree.is_tensor: assert degree.shape is not None updates |= degree.shape.set_values([]) return coef_status & degree_status, updates diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 60372ea1..fbd9c14b 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -16,7 +16,7 @@ from collections.abc import Mapping from types import UnionType -from typing import Any, Self, get_origin +from typing import Any, Self from ...utils.utils import OrderedSet, find_dominant_type from ..common import ( @@ -718,7 +718,7 @@ def extend( con_obj, _updates = self._add_connection(model, local_key, value, updates) updates |= _updates submodel_dag[local_key] = con_obj - if get_origin(con_obj.metadata.edge_type) is Tensor: + if con_obj.metadata.is_tensor: updates.shape_updates.add(con_obj.metadata) # Replace shape info keys, which are local keys, with global equivalents. diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index 687e324d..b6e95482 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -69,11 +69,16 @@ def __init__( output_data: IOHyperEdge | None = None for key, value in kwargs.items(): if isinstance(value, BaseKey): - if get_origin(value.type) is Tensor: + if get_origin(value.type) is Tensor or value.type is Tensor: + val_type = ( + Tensor[int | float | bool] + if value.type is Tensor + else value.type + ) if not isinstance(tensor := value.value, Tensor): assert isinstance(value.value, ToBeDetermined) tensor = Tensor( - type=get_args(value.type)[0], + type=get_args(val_type)[0], shape=shapes[key].node, ) edge = IOHyperEdge(value=tensor, interval=value.interval) @@ -179,10 +184,7 @@ def extract_connection_info( # try to find outer key's real name in data_to_key_map outer_key = data_to_key_map.get(key_data, [key]) outer_key = ["'" + key + "'" for key in outer_key] - if ( - get_origin(key_data.edge_type) is not Tensor - and key_data.value is not TBD - ): + if not key_data.is_tensor and key_data.value is not TBD: # If value of the scalar is determined, write that value directly. outer_key = [str(key_data.value)] conn.extend(outer_key) diff --git a/mithril/framework/physical/data_store.py b/mithril/framework/physical/data_store.py index 8275670f..601fdbad 100644 --- a/mithril/framework/physical/data_store.py +++ b/mithril/framework/physical/data_store.py @@ -14,7 +14,7 @@ from collections.abc import Mapping, Sequence from copy import deepcopy -from typing import Any, Generic, TypeGuard, get_origin +from typing import Any, Generic, TypeGuard from ...backends.backend import Backend from ...core import Constant, DataType, Dtype, data_types, epsilon_table @@ -164,7 +164,7 @@ def _set_data_value(self, key: str, data: IOHyperEdge) -> None: if isinstance(value, Constant): value = epsilon_table[self.backend.precision][value] - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: value = self.backend.array(value) elif isinstance(value, Dtype): value = getattr(self.backend, value.name) @@ -224,7 +224,7 @@ def set_shapes( if isinstance(key, Connection): key = key.key assert isinstance(key, str) - if get_origin((data := self._all_data[key]).edge_type) is not Tensor: + if not (data := self._all_data[key]).is_tensor: raise ValueError("Non-tensor data can not have shape!") assert data.shape is not None updates |= data.shape.set_values(value) @@ -272,7 +272,7 @@ def set_static_keys( raise KeyError( "Requires static key to be in the input keys of the model!" ) - if (get_origin(self._all_data[key].edge_type) is Tensor) and not isinstance( + if self._all_data[key].is_tensor and not isinstance( value, ToBeDetermined | self.backend.get_backend_array_type() ): raise ValueError( diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 38828602..61dbbda5 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -20,7 +20,6 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from dataclasses import dataclass -from typing import get_origin from ...backends.backend import Backend, ParallelBackend from ...core import DataType, GenericDataType @@ -204,7 +203,7 @@ def __init__( elif global_key in self._trainable_tensor_inputs: # if physical_data.edge_type not in (Tensor, ToBeDetermined): if not ( - get_origin(physical_data.edge_type) is Tensor + physical_data.is_tensor or physical_data.edge_type is ToBeDetermined ): raise ValueError( @@ -224,7 +223,7 @@ def __init__( if key_shape := model_shapes.get(key): data = model_data[key] - assert get_origin(data.edge_type) is Tensor + assert data.is_tensor shp = data.shape assert shp is not None # assert shp is not None @@ -435,7 +434,7 @@ def _infer_differentiability(self, model_data: dict[str, IOHyperEdge]) -> None: # that have a Tensor type output. output_key = PrimitiveModel.output_key output_edge = model_data[output_key] - if get_origin(output_edge.edge_type) is Tensor: + if output_edge.is_tensor: # If any of the inputs are differentiable, then # the output is also differentiable. for key, value in model_data.items(): @@ -552,7 +551,7 @@ def _pre_compile( for value in self.data_store.intermediate_non_differentiables.inverse: # there can exist some inferred intermediate scalar keys in logical model. # find those keys and add to cached datas - if (get_origin(value.edge_type) is not Tensor) and (value.value is not TBD): + if not value.is_tensor and (value.value is not TBD): updates.add(value) self.data_store.update_cached_data(updates) @@ -607,7 +606,7 @@ def _pre_compile( # but not unnecessary in flat_graph. This case should be handled when # flat_graph - data_store integration is updated. if conn_edge is not None and ( - (get_origin(conn_edge.edge_type) is not Tensor) + (not conn_edge.is_tensor) or ( (not find_intersection_type(float, conn_edge.value_type)) or _key @@ -960,7 +959,7 @@ def extract_connection_info( # model. Indicate it accordingly input_name = "'" + connection.key + "'" input_data = model.conns.all[input_key].metadata - if get_origin(input_data.edge_type) is not Tensor: + if not input_data.is_tensor: # If value of the scalar is determined, write that value pm_input_data = self.data_store.data_memo[id(input_data)] if (val := pm_input_data.value) is not TBD: diff --git a/mithril/models/models.py b/mithril/models/models.py index 1ba2a458..fe933ecc 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -30,8 +30,6 @@ ShapeTemplateType, Tensor, ToBeDetermined, - TypeVarTensorType, - ValType, ) from ..framework.constraints import polynomial_kernel_constraint from ..framework.logical.base import BaseModel, ExtendInfo @@ -166,7 +164,7 @@ def __init__( stride: int | None | ToBeDetermined = None, padding: int | PaddingType | tuple[int, int] | ToBeDetermined = (0, 0), dilation: int | ToBeDetermined = 1, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -258,7 +256,7 @@ def __init__( stride: int | None | tuple[int, int] | ToBeDetermined = None, padding: int | PaddingType | tuple[int, int] | ToBeDetermined = (0, 0), dilation: int | ToBeDetermined = 1, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -358,8 +356,8 @@ def __init__( padding: int | PaddingType | tuple[int, int] | ToBeDetermined = 0, dilation: int | ToBeDetermined = 1, use_bias: bool = True, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -440,8 +438,8 @@ def __init__( | ToBeDetermined = (0, 0), dilation: int | tuple[int, int] | ToBeDetermined = (1, 1), use_bias: bool = True, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -525,9 +523,9 @@ def __init__( self, dimension: int | None = None, use_bias: bool = True, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -546,7 +544,7 @@ def __init__( input_key = IOKey(name="input", value=input) weight_key = IOKey(name="weight", value=weight).transpose() if use_bias: - bias_key = IOKey(name="bias", value=bias, type=Tensor[ValType]) + bias_key = IOKey(name="bias", value=bias, type=Tensor[int | float | bool]) self |= mult(left=input_key, right=weight_key) self |= Add()(left=mult.output, right=bias_key, output=output) shapes["bias"] = [dim] @@ -584,9 +582,9 @@ class ElementWiseAffine(Model): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -632,9 +630,9 @@ def __init__( self, activation: BaseModel, dimension: int | None = None, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -676,9 +674,9 @@ def __init__( use_scale: bool = True, use_bias: bool = True, eps: float = 1e-5, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -688,10 +686,12 @@ def __init__( # Expects its input shape as [B, ..., d] d refers to normalized dimension mean = Mean(axis=-1, keepdim=True) numerator = Subtract() - numerator.set_types(left=Tensor[ValType], right=Tensor[ValType]) + numerator.set_types( + left=Tensor[int | float | bool], right=Tensor[int | float | bool] + ) var = Variance(axis=-1, correction=0, keepdim=True) add = Add() - add.set_types(left=Tensor[ValType]) + add.set_types(left=Tensor[int | float | bool]) denominator = Sqrt() in_key = IOKey("input", value=input) self += mean(input=in_key) @@ -711,13 +711,17 @@ def __init__( if use_scale: mult = Multiply() - mult.set_types(left=Tensor[ValType], right=Tensor[ValType]) + mult.set_types( + left=Tensor[int | float | bool], right=Tensor[int | float | bool] + ) self += mult(left=self.cout, right=IOKey("weight", value=weight)) mult._set_shapes(shapes) if use_bias: add = Add() - add.set_types(left=Tensor[ValType], right=Tensor[ValType]) + add.set_types( + left=Tensor[int | float | bool], right=Tensor[int | float | bool] + ) self += add(left=self.cout, right=IOKey("bias", value=bias)) add._set_shapes(shapes) # TODO: Remove below Buffer after required naming-related changes are done. @@ -758,10 +762,10 @@ def __init__( use_scale: bool = True, use_bias: bool = True, eps: float = 1e-5, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, name: str | None = None, ) -> None: super().__init__(name=name) @@ -832,7 +836,7 @@ class L1(Model): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -862,7 +866,7 @@ class L2(Model): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -895,8 +899,8 @@ class QuadraticFormRegularizer(Model): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - kernel: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + kernel: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -940,10 +944,10 @@ class RBFKernel(Model): def __init__( self, - input1: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - input2: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - l_scale: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - sigma: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input1: Tensor[int | float | bool] | ToBeDetermined = TBD, + input2: Tensor[int | float | bool] | ToBeDetermined = TBD, + l_scale: Tensor[int | float | bool] | ToBeDetermined = TBD, + sigma: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1018,10 +1022,10 @@ class PolynomialKernel(Model): def __init__( self, robust: bool = True, - input1: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - input2: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - poly_coef: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - degree: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input1: Tensor[int | float | bool] | ToBeDetermined = TBD, + input2: Tensor[int | float | bool] | ToBeDetermined = TBD, + poly_coef: Tensor[int | float | bool] | ToBeDetermined = TBD, + degree: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1083,11 +1087,11 @@ class KernelizedSVM(Model): def __init__( self, kernel: BaseModel, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, - **kwargs: Tensor[TypeVarTensorType] | ToBeDetermined, + **kwargs: Tensor[int | float | bool] | ToBeDetermined, ) -> None: if len(kernel.input_keys) < 2: raise KeyError("Kernel requires at least two inputs!") @@ -1154,9 +1158,9 @@ class LinearSVM(Model): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1205,9 +1209,9 @@ class LogisticRegression(Model): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1256,10 +1260,10 @@ def __init__( activations: list[BaseModel], dimensions: Sequence[int | None], input_name_templates: dict[str, str] | None = None, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, - **weights_biases: Tensor[TypeVarTensorType] | ToBeDetermined, + **weights_biases: Tensor[int | float | bool] | ToBeDetermined, ) -> None: super().__init__(name=name) self.factory_args = {"activations": activations, "dimensions": dimensions} @@ -1364,12 +1368,12 @@ class RNNCell(Cell): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_ih: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_hh: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_ho: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_h: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_ih: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_hh: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_ho: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_h: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_o: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1491,17 +1495,17 @@ class LSTMCell(Cell): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_i: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_f: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_c: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_out: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_f: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_i: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_c: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_out: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_i: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_f: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_c: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_o: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_out: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_f: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_i: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_c: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_o: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_out: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1662,17 +1666,17 @@ class LSTMCellBody(Model): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - prev_hidden: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - prev_cell: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_i: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_f: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_c: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - w_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_f: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_i: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_c: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias_o: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + prev_hidden: Tensor[int | float | bool] | ToBeDetermined = TBD, + prev_cell: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_i: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_f: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_c: Tensor[int | float | bool] | ToBeDetermined = TBD, + w_o: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_f: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_i: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_c: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias_o: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -1795,7 +1799,7 @@ def __init__( cell_type: Cell, *, name: str | None = None, - # **kwargs: Tensor[TypeVarTensorType] | MainValueType, + # **kwargs: Tensor[int | float | bool] | MainValueType, ) -> None: self.cell_type = cell_type super().__init__(name=name) @@ -1813,10 +1817,10 @@ def __init__( cell_type: Cell, max_sequence_length: int, teacher_forcing: bool = False, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, - **kwargs: Tensor[TypeVarTensorType] | MainValueType, + **kwargs: Tensor[int | float | bool] | MainValueType, ) -> None: super().__init__(cell_type=cell_type, name=name) @@ -1898,10 +1902,10 @@ def __init__( self, cell_type: Cell, max_sequence_length: int, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, - **kwargs: Tensor[TypeVarTensorType] | ToBeDetermined, + **kwargs: Tensor[int | float | bool] | ToBeDetermined, ) -> None: super().__init__(cell_type=cell_type, name=name) @@ -1955,10 +1959,10 @@ def __init__( self, cell_type: Cell, max_sequence_length: int, - hidden_concat: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + hidden_concat: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, - **kwargs: Tensor[TypeVarTensorType] | ToBeDetermined, + **kwargs: Tensor[int | float | bool] | ToBeDetermined, ) -> None: super().__init__(cell_type, name=name) @@ -2038,7 +2042,7 @@ def __init__( max_input_sequence_length: int, max_target_sequence_length: int, teacher_forcing: bool = False, - indices: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + indices: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2141,8 +2145,8 @@ def __init__( self, get_final_distance: bool = True, robust: bool = True, - input1: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - input2: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input1: Tensor[int | float | bool] | ToBeDetermined = TBD, + input2: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2201,9 +2205,9 @@ def __init__( self, degree: int, dimension: int | None = None, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - weight: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - bias: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + weight: Tensor[int | float | bool] | ToBeDetermined = TBD, + bias: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2245,9 +2249,9 @@ def __init__( self, exact_distances: bool = True, robust: bool = True, - distances: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - pred_distances: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - norm: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + distances: Tensor[int | float | bool] | ToBeDetermined = TBD, + pred_distances: Tensor[int | float | bool] | ToBeDetermined = TBD, + norm: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2345,9 +2349,9 @@ def __init__( exact_distances: bool = True, calculate_p_joint: bool = False, perplexity: float = 20.0, - distances: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - pred_distances: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - p_joint: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + distances: Tensor[int | float | bool] | ToBeDetermined = TBD, + pred_distances: Tensor[int | float | bool] | ToBeDetermined = TBD, + p_joint: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2442,10 +2446,10 @@ def __init__( self, base_model: MDSCore | TSNECore, input_type: str = "distances", - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - norm: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - predicted_coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + coords: Tensor[int | float | bool] | ToBeDetermined = TBD, + norm: Tensor[int | float | bool] | ToBeDetermined = TBD, + predicted_coords: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2551,10 +2555,10 @@ def __init__( self, prediction_dim: int, input_type: str = "distances", - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - norm: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - predicted_coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + coords: Tensor[int | float | bool] | ToBeDetermined = TBD, + norm: Tensor[int | float | bool] | ToBeDetermined = TBD, + predicted_coords: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2609,9 +2613,9 @@ def __init__( input_type: str = "distances", preplexity: float = 20.0, calculate_p_joint: bool = False, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - norm: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - predicted_coords: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, + norm: Tensor[int | float | bool] | ToBeDetermined = TBD, + predicted_coords: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2665,14 +2669,14 @@ class GaussProcessRegressionCore(Model): def __init__( self, - s: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - k: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - k_star: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - mu: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - loss: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - prediction: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - confidence: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + s: Tensor[int | float | bool] | ToBeDetermined = TBD, + k: Tensor[int | float | bool] | ToBeDetermined = TBD, + k_star: Tensor[int | float | bool] | ToBeDetermined = TBD, + mu: Tensor[int | float | bool] | ToBeDetermined = TBD, + label: Tensor[int | float | bool] | ToBeDetermined = TBD, + loss: Tensor[int | float | bool] | ToBeDetermined = TBD, + prediction: Tensor[int | float | bool] | ToBeDetermined = TBD, + confidence: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2792,11 +2796,11 @@ class GPRLoss(Model): def __init__( self, robust: bool = False, - labels: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - mu: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - L: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - K_term: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - alpha: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + labels: Tensor[int | float | bool] | ToBeDetermined = TBD, + mu: Tensor[int | float | bool] | ToBeDetermined = TBD, + L: Tensor[int | float | bool] | ToBeDetermined = TBD, + K_term: Tensor[int | float | bool] | ToBeDetermined = TBD, + alpha: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2884,8 +2888,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + pred: Tensor[int | float | bool] | ToBeDetermined = TBD, + label: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -2959,8 +2963,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + pred: Tensor[int | float | bool] | ToBeDetermined = TBD, + label: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3021,8 +3025,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + pred: Tensor[int | float | bool] | ToBeDetermined = TBD, + label: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3182,8 +3186,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + pred: Tensor[int | float | bool] | ToBeDetermined = TBD, + label: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3342,8 +3346,8 @@ def __init__( is_binary: bool = False, is_pred_one_hot: bool = True, is_label_one_hot: bool = True, - pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + pred: Tensor[int | float | bool] | ToBeDetermined = TBD, + label: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3498,8 +3502,8 @@ def __init__( self, n_classes: int, is_label_one_hot: bool = True, - pred: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, - label: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + pred: Tensor[int | float | bool] | ToBeDetermined = TBD, + label: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: @@ -3556,7 +3560,7 @@ class SiLU(Model): def __init__( self, - input: Tensor[TypeVarTensorType] | ToBeDetermined = TBD, + input: Tensor[int | float | bool] | ToBeDetermined = TBD, *, name: str | None = None, ) -> None: diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 09f540de..9c982c81 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -290,7 +290,7 @@ def model_to_dict(model: BaseModel) -> TrainModelDict | ModelDict: for key, con in model.conns.all.items(): edge = con.metadata - if get_origin(edge.edge_type) is Tensor and not con.is_key_autogenerated: + if edge.is_tensor and not con.is_key_autogenerated: differentiablility_info[key] = edge.differentiable for shape in model.assigned_shapes: @@ -387,7 +387,7 @@ def connection_to_dict( elif is_valued and connection in model.conns.input_connections: val = connection.metadata.value assert not isinstance(val, ToBeDetermined) - if get_origin(connection.metadata.edge_type) is Tensor: + if connection.metadata.is_tensor: val = {"tensor": val} if connection.key.startswith("$"): key_value = val diff --git a/mithril/utils/func_utils.py b/mithril/utils/func_utils.py index 31ba3d59..53f238b5 100644 --- a/mithril/utils/func_utils.py +++ b/mithril/utils/func_utils.py @@ -14,14 +14,13 @@ from collections.abc import Callable from copy import deepcopy -from typing import Any, get_origin +from typing import Any from ..core import DataType from ..framework.common import ( # , Scalar, Tensor TBD, DataEvalType, IOHyperEdge, - Tensor, ) KeyMapType = dict[str, str] @@ -162,7 +161,7 @@ def reorganize_args( def is_make_array_required(data: IOHyperEdge) -> bool: - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: assert data.shape is not None _temp_shape = next(iter(data.shape.reprs)) # It is needed to guarantee that Tensor is at least one dimensional. diff --git a/tests/scripts/helper.py b/tests/scripts/helper.py index 8c5d2f71..d3a315c5 100644 --- a/tests/scripts/helper.py +++ b/tests/scripts/helper.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import get_origin from mithril import Backend, Constant, compile, epsilon_table from mithril.framework.common import IOHyperEdge, Tensor @@ -57,7 +56,7 @@ def evaluate_case( model = finalize_model(current_case) # Convert static keys to array if they are not scalar. for key, value in static_keys.items(): - if get_origin(model.conns.get_metadata(key).edge_type) is not Tensor: + if not model.conns.get_metadata(key).is_tensor: static_keys[key] = value else: static_keys[key] = convert_to_array(backend, value) @@ -92,7 +91,7 @@ def evaluate_case( data_value = epsilon_table[backend.precision][data_value] assert data_value == copied_data.value - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: assert id(data.value) == id(copied_data.value) # Evaluate model. diff --git a/tests/scripts/test_constr_counter.py b/tests/scripts/test_constr_counter.py index eabc17f8..85544a8d 100644 --- a/tests/scripts/test_constr_counter.py +++ b/tests/scripts/test_constr_counter.py @@ -13,8 +13,6 @@ # limitations under the License. -from typing import get_origin - import pytest from mithril.framework.common import ( @@ -50,12 +48,8 @@ def dummy_constraint(output: IOHyperEdge, input: IOHyperEdge): # updated_symbols = set() updates = Updates() status = False - output_repr = ( - output._temp_shape if get_origin(output.edge_type) is Tensor else output.value - ) - input_repr = ( - input._temp_shape if get_origin(input.edge_type) is Tensor else input.value - ) + output_repr = output._temp_shape if output.is_tensor else output.value + input_repr = input._temp_shape if input.is_tensor else input.value assert isinstance(output_repr, ShapeRepr) assert isinstance(input_repr, ShapeRepr) if bool(input_repr.root) ^ bool(output_repr.root): diff --git a/tests/scripts/test_constraints.py b/tests/scripts/test_constraints.py index b66518e2..9dd0b30c 100644 --- a/tests/scripts/test_constraints.py +++ b/tests/scripts/test_constraints.py @@ -15,7 +15,7 @@ from collections.abc import Callable, Mapping from copy import deepcopy from types import EllipsisType, NoneType, UnionType -from typing import Any, TypeGuard, get_origin +from typing import Any, TypeGuard import numpy as np import pytest @@ -218,7 +218,7 @@ def assert_shape_results( shapes = {} assignments: AssignmentType = {} for key, value in data.items(): - if get_origin(value.edge_type) is Tensor: + if value.is_tensor: assert value.shape is not None shapes[key] = value.shape.get_shapes(uni_cache, var_cache, verbose=True) shape_repr = value._temp_shape @@ -264,7 +264,7 @@ def assert_value_results( assert data[key].value == value else: # If value is a tensor of any supported backend. - assert get_origin(data[key].edge_type) is Tensor + assert data[key].is_tensor d_val = data[key].value assert GenericDataType.is_tensor_type(d_val) assert (d_val == value).all() diff --git a/tests/scripts/test_hyperedge.py b/tests/scripts/test_hyperedge.py index b9e60f56..8fe62930 100644 --- a/tests/scripts/test_hyperedge.py +++ b/tests/scripts/test_hyperedge.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import get_origin import pytest @@ -35,7 +34,7 @@ def test_init_with_tensor_default_type(): edge = IOHyperEdge(Tensor[int | float | bool]) assert ( - get_origin(edge.edge_type) is Tensor + edge.is_tensor and isinstance(edge._value, Tensor) and edge.value_type == int | float | bool and edge.value is TBD @@ -47,7 +46,7 @@ def test_init_with_tensor_default_type(): def test_init_with_tensor_int_or_float_type(): edge = IOHyperEdge(Tensor[int | float]) assert ( - get_origin(edge.edge_type) is Tensor + edge.is_tensor and isinstance(edge._value, Tensor) and edge.value_type == int | float and edge.value is TBD @@ -59,7 +58,7 @@ def test_init_with_tensor_int_or_float_type(): def test_init_with_tensor_type_tensor_value(): edge = IOHyperEdge(Tensor[int | float], value=Tensor([[2.0]])) assert ( - get_origin(edge.edge_type) is Tensor + edge.is_tensor and isinstance(edge._value, Tensor) and edge.value_type is float and edge.value == [[2.0]] @@ -116,11 +115,7 @@ def test_set_tensor_type(): assert edge.edge_type is ToBeDetermined and edge._value is TBD and edge.value is TBD assert edge.shape is None edge.set_type(Tensor[int | float | bool]) - assert ( - get_origin(edge.edge_type) is Tensor - and isinstance(edge._value, Tensor) - and edge.value is TBD - ) + assert edge.is_tensor and isinstance(edge._value, Tensor) and edge.value is TBD assert edge.value_type == int | float | bool assert isinstance(edge.shape, ShapeNode) assert edge in edge._value.referees and edge in edge._value.shape.referees @@ -131,11 +126,7 @@ def test_set_generic_tensor_type(): assert edge.edge_type is ToBeDetermined and edge._value is TBD and edge.value is TBD assert edge.shape is None edge.set_type(Tensor[int | float]) - assert ( - get_origin(edge.edge_type) is Tensor - and isinstance(edge._value, Tensor) - and edge.value is TBD - ) + assert edge.is_tensor and isinstance(edge._value, Tensor) and edge.value is TBD assert edge.value_type == int | float assert isinstance(edge.shape, ShapeNode) assert edge in edge._value.referees and edge in edge._value.shape.referees @@ -179,7 +170,7 @@ def test_init_with_tensor_value(): tensor: Tensor[float] = Tensor([[2.0]], shape=shape_node) edge = IOHyperEdge(value=tensor) assert ( - get_origin(edge.edge_type) is Tensor + edge.is_tensor and isinstance(edge._value, Tensor) and edge.value_type is float and edge.value == [[2.0]] @@ -198,7 +189,7 @@ def test_set_non_typed_edge_with_tensor_value(): edge = IOHyperEdge() edge.set_value(tensor) assert ( - get_origin(edge.edge_type) is Tensor + edge.is_tensor and isinstance(edge._value, Tensor) and edge.value_type is float and edge.value == [[2.0]] @@ -221,7 +212,7 @@ def test_set_tensor_edge_with_tensor_value(): edge_tensor = edge._value edge.set_value(tensor) assert ( - get_origin(edge.edge_type) is Tensor + edge.is_tensor and isinstance(edge._value, Tensor) and edge.value_type is float and edge.value == [[2.0]] diff --git a/tests/scripts/test_io_key.py b/tests/scripts/test_io_key.py index e0e8acef..ae5835de 100644 --- a/tests/scripts/test_io_key.py +++ b/tests/scripts/test_io_key.py @@ -13,7 +13,6 @@ # limitations under the License. from itertools import product -from typing import get_origin import numpy as np import pytest @@ -663,7 +662,7 @@ def test_iokey_values_12(): model += sig_model_2( input=IOKey(shape=[1, 2, 3, 4], name="input"), output=IOKey(name="output2") ) - assert get_origin(sig_model_1.input.data.metadata.edge_type) is Tensor + assert sig_model_1.input.metadata.is_tensor assert sig_model_1.input.data.metadata.shape is not None assert sig_model_1.input.data.metadata.shape.get_shapes() == [1, 2, 3, 4] diff --git a/tests/scripts/test_randomized_models_all_backends.py b/tests/scripts/test_randomized_models_all_backends.py index 93c46666..f485765e 100644 --- a/tests/scripts/test_randomized_models_all_backends.py +++ b/tests/scripts/test_randomized_models_all_backends.py @@ -15,7 +15,6 @@ import inspect import json from copy import deepcopy -from typing import get_origin import numpy as np import pytest @@ -23,7 +22,6 @@ import mithril as ml from mithril import JaxBackend, MlxBackend, NumpyBackend, TorchBackend, compile, models from mithril.backends.utils import DtypeBits -from mithril.framework.common import Tensor from mithril.utils.dict_conversions import dict_to_model from tests.scripts.test_utils import ( dict_to_random, @@ -180,7 +178,7 @@ def test_randomized(case: str) -> None: ) static_inputs[init_key] = { key: init_backend.array(value) - if get_origin(model.conns.get_metadata(key).edge_type) is Tensor + if model.conns.get_metadata(key).is_tensor else value for key, value in static_inputs[init_key].items() } @@ -230,7 +228,7 @@ def test_randomized(case: str) -> None: } static_inputs[backend.backend_type] = { key: backend.array(value) - if get_origin(model.conns.get_metadata(key).edge_type) is Tensor + if model.conns.get_metadata(key).is_tensor else value for key, value in static_inputs[init_key].items() } diff --git a/tests/scripts/test_ref_counts.py b/tests/scripts/test_ref_counts.py index 6b788370..3c19d632 100644 --- a/tests/scripts/test_ref_counts.py +++ b/tests/scripts/test_ref_counts.py @@ -15,7 +15,6 @@ import sys from copy import deepcopy -from typing import get_origin from mithril.framework.common import ( NOT_GIVEN, @@ -90,8 +89,8 @@ def __init__(self) -> None: submodel1 = TestModel() submodel2 = TestModel() - assert get_origin(submodel1.output.metadata.edge_type) is Tensor - assert get_origin(submodel2.output.metadata.edge_type) is Tensor + assert submodel1.output.metadata.is_tensor + assert submodel2.output.metadata.is_tensor assert submodel1.output.metadata.shape is not None assert submodel2.input.metadata.shape is not None ref_var1 = next(iter(submodel1.output.metadata.shape.reprs)).root @@ -122,8 +121,8 @@ def __init__(self) -> None: buff_model1 = MyModel() buff_model2 = MyModel() - assert get_origin(buff_model1.input.metadata.edge_type) is Tensor - assert get_origin(buff_model2.input.metadata.edge_type) is Tensor + assert buff_model1.input.metadata.is_tensor + assert buff_model2.input.metadata.is_tensor assert buff_model1.input.metadata.shape is not None assert buff_model2.output.metadata.shape is not None ref_var1 = next(iter(buff_model1.input.metadata.shape.reprs)).root @@ -1379,8 +1378,8 @@ def __init__(self) -> None: buff_model1 = MyModel() buff_model2 = MyModel() - assert get_origin(buff_model1.output.metadata.edge_type) is Tensor - assert get_origin(buff_model2.input.metadata.edge_type) is Tensor + assert buff_model1.output.metadata.is_tensor + assert buff_model2.input.metadata.is_tensor assert buff_model1.output.metadata.shape is not None assert buff_model2.input.metadata.shape is not None ref_var1 = next(iter(buff_model1.output.metadata.shape.reprs))[0] @@ -1392,7 +1391,7 @@ def __init__(self) -> None: diff_roots = set() for tensor in get_all_data(model): - assert get_origin(tensor.edge_type) is Tensor + assert tensor.is_tensor node = tensor.shape assert node is not None for repr in node.reprs: @@ -1420,8 +1419,8 @@ def __init__(self) -> None: submodel1 = MyModel() submodel2 = MyModel() - assert get_origin(submodel1.output.metadata.edge_type) is Tensor - assert get_origin(submodel2.input.metadata.edge_type) is Tensor + assert submodel1.output.metadata.is_tensor + assert submodel2.input.metadata.is_tensor assert submodel1.output.metadata.shape is not None assert submodel2.input.metadata.shape is not None ref_var1 = next(iter(submodel1.output.metadata.shape.reprs))[0] @@ -1452,8 +1451,8 @@ def __init__(self) -> None: submodel1 = MyModel() submodel2 = MyModel() - assert get_origin(submodel1.output.metadata.edge_type) is Tensor - assert get_origin(submodel2.input.metadata.edge_type) is Tensor + assert submodel1.output.metadata.is_tensor + assert submodel2.input.metadata.is_tensor assert submodel1.output.metadata.shape is not None assert submodel2.input.metadata.shape is not None ref_var1 = next(iter(submodel1.output.metadata.shape.reprs)) @@ -1484,8 +1483,8 @@ def __init__(self) -> None: submodel1 = MyModel() submodel2 = MyModel() - assert get_origin(submodel1.output.metadata.edge_type) is Tensor - assert get_origin(submodel2.input.metadata.edge_type) is Tensor + assert submodel1.output.metadata.is_tensor + assert submodel2.input.metadata.is_tensor ref_var1 = submodel1.output.metadata.shape ref_var2 = submodel2.input.metadata.shape diff --git a/tests/scripts/test_scripts.py b/tests/scripts/test_scripts.py index 3b34418b..7d8bbc39 100644 --- a/tests/scripts/test_scripts.py +++ b/tests/scripts/test_scripts.py @@ -19,7 +19,6 @@ import typing from copy import deepcopy from functools import partial -from typing import get_origin import jax import mlx.core as mx @@ -5799,7 +5798,7 @@ def test_deepcopy_1(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: assert id(data.value) == id(copied_data.value) @@ -5827,7 +5826,7 @@ def test_deepcopy_2(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: assert id(data.value) == id(copied_data.value) @@ -5855,7 +5854,7 @@ def test_deepcopy_3(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: assert id(data.value) == id(copied_data.value) @@ -5880,7 +5879,7 @@ def test_deepcopy_4(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: assert id(data.value) == id(copied_data.value) @@ -5916,7 +5915,7 @@ def test_deepcopy_5(): if copied_data not in unused_data: assert isinstance(copied_data, IOHyperEdge) assert data.value == copied_data.value - if get_origin(data.edge_type) is Tensor: + if data.is_tensor: assert id(data.value) == id(copied_data.value) diff --git a/tests/scripts/test_shapes.py b/tests/scripts/test_shapes.py index ffe14582..b6dd2160 100644 --- a/tests/scripts/test_shapes.py +++ b/tests/scripts/test_shapes.py @@ -16,7 +16,6 @@ from copy import deepcopy from itertools import combinations, product from types import EllipsisType, NoneType -from typing import get_origin import numpy as np import pytest @@ -4437,7 +4436,7 @@ def test_total_repr_count(): edge = var2.input.data.metadata - assert get_origin(edge.edge_type) is Tensor + assert edge.is_tensor assert edge.shape is not None assert len(edge.shape.reprs) == 2 @@ -4445,7 +4444,7 @@ def test_total_repr_count(): def test_total_repr_count_linear_1(): model = Linear() edge = model.input.metadata - assert get_origin(edge.edge_type) is Tensor + assert edge.is_tensor assert edge.shape is not None shp_repr = next(iter(edge.shape.reprs)) @@ -5724,12 +5723,12 @@ def __call__( # type: ignore[override] model.set_shapes(shape_1) model.set_shapes(shape_2) in_data = model.input.metadata - assert get_origin(in_data.edge_type) is Tensor + assert in_data.is_tensor assert (node := in_data.shape) is not None input_repr = next(iter(node.reprs)) out_data = model.output.metadata - assert get_origin(out_data.edge_type) is Tensor + assert out_data.is_tensor assert out_data.shape is not None output_repr = next(iter(out_data.shape.reprs)) @@ -5768,10 +5767,10 @@ def __call__( # type: ignore[override] model.set_shapes(shape_1) model.set_shapes(shape_2) - assert get_origin(model.input.metadata.edge_type) is Tensor + assert model.input.metadata.is_tensor assert (in_node := model.input.metadata.shape) is not None input_repr = next(iter(in_node.reprs)) - assert get_origin(model.output.metadata.edge_type) is Tensor + assert model.output.metadata.is_tensor assert (out_node := model.output.metadata.shape) is not None output_repr = next(iter(out_node.reprs)) @@ -5856,7 +5855,7 @@ def test_same_uniadic_5(): buffer.set_shapes(shape_1) buffer.set_shapes(shape_2) - assert get_origin(buffer.input.metadata.edge_type) is Tensor + assert buffer.input.metadata.is_tensor assert buffer.input.metadata.shape is not None input_reprs = buffer.input.metadata.shape.reprs @@ -6724,9 +6723,7 @@ def find_all_reprs(repr: ShapeRepr, repr_cache=None) -> set[ShapeRepr]: # find connections only with tensor data all_tensor_conns = { - con - for con in model.conns.all.values() - if get_origin(con.metadata.edge_type) is Tensor + con for con in model.conns.all.values() if con.metadata.is_tensor } # Find all reprs that are linked to shape reprs of the tensors @@ -6864,18 +6861,18 @@ def test_node_count_1(): # Check total existing node count all_nodes = set() for con in model.conns.all.values(): - assert get_origin(con.metadata.edge_type) is Tensor + assert con.metadata.is_tensor all_nodes.add(con.metadata.shape) assert len(all_nodes) == 1 # Check total variadics repr count - assert get_origin(sub_model.input.metadata.edge_type) is Tensor + assert sub_model.input.metadata.is_tensor assert (in_node := sub_model.input.metadata.shape) is not None assert (in_repr := next(iter(in_node.reprs))) is not None assert in_repr.root is not None assert len(in_repr.root.reprs) == 1 - assert get_origin(sub_model.output.metadata.edge_type) is Tensor + assert sub_model.output.metadata.is_tensor assert (out_node := sub_model.output.metadata.shape) is not None assert (out_repr := next(iter(out_node.reprs))) is not None assert out_repr.root is not None @@ -6895,7 +6892,7 @@ def test_node_count_2(): all_nodes = set() for con in model.conns.all.values(): edge = con.metadata - assert get_origin(edge.edge_type) is Tensor + assert edge.is_tensor all_nodes.add(edge.shape) assert len(all_nodes) == 2 @@ -6917,7 +6914,7 @@ def test_node_count_3(): shapes = set() for con in model.conns.all.values(): edge = con.metadata - assert get_origin(edge.edge_type) is Tensor + assert edge.is_tensor shapes.add(edge.shape) assert len(shapes) == 3 @@ -7456,7 +7453,7 @@ def __call__( # type: ignore[override] model += Buffer() model += test_model all_nodes = get_all_nodes(model) - assert get_origin(buff_model.input.metadata.edge_type) is Tensor + assert buff_model.input.metadata.is_tensor ref_all_nodes = { test_model.output.metadata.shape, # type: ignore buff_model.input.metadata.shape, @@ -7496,7 +7493,7 @@ def __call__( # type: ignore[override] all_nodes = get_all_nodes(model) data = buff_model.input.metadata - assert get_origin(data.edge_type) is Tensor + assert data.is_tensor ref_all_nodes = {data.shape} assert all_nodes == ref_all_nodes @@ -7512,7 +7509,7 @@ def test_node_count_6(): all_nodes = get_all_nodes(model) data = buff_model.input.metadata - assert get_origin(data.edge_type) is Tensor + assert data.is_tensor ref_all_nodes = {data.shape} assert all_nodes == ref_all_nodes @@ -7527,7 +7524,7 @@ def test_node_count_7(): for _ in range(5): model += deepcopy(model) all_nodes = get_all_nodes(model) - assert get_origin(buff_model.input.metadata.edge_type) is Tensor + assert buff_model.input.metadata.is_tensor ref_all_nodes = {buff_model.input.metadata.shape} assert all_nodes == ref_all_nodes @@ -7721,7 +7718,7 @@ def __call__( # type: ignore[override] all_nodes = get_all_nodes(model) data = test_model.input.metadata - assert get_origin(data.edge_type) is Tensor + assert data.is_tensor ref_all_nodes = {data.shape} assert all_nodes == ref_all_nodes @@ -7750,7 +7747,7 @@ def __call__( # type: ignore[override] model += MyModel() all_nodes = get_all_nodes(model) - assert get_origin(test_model.input.metadata.edge_type) is Tensor + assert test_model.input.metadata.is_tensor ref_all_nodes = {test_model.input.metadata.shape} assert all_nodes == ref_all_nodes @@ -7780,9 +7777,9 @@ def __call__( # type: ignore[override] model += MyModel() all_nodes = get_all_nodes(model) - assert get_origin(test_model1.input.metadata.edge_type) is Tensor - assert get_origin(test_model2.input.metadata.edge_type) is Tensor - assert get_origin(test_model3.input.metadata.edge_type) is Tensor + assert test_model1.input.metadata.is_tensor + assert test_model2.input.metadata.is_tensor + assert test_model3.input.metadata.is_tensor ref_all_nodes = { test_model1.input.metadata.shape, test_model2.input.metadata.shape, @@ -7816,9 +7813,9 @@ def __call__( # type: ignore[override] model += MyModel() all_nodes = get_all_nodes(model) - assert get_origin(test_model1.input.metadata.edge_type) is Tensor - assert get_origin(test_model2.input.metadata.edge_type) is Tensor - assert get_origin(test_model3.input.metadata.edge_type) is Tensor + assert test_model1.input.metadata.is_tensor + assert test_model2.input.metadata.is_tensor + assert test_model3.input.metadata.is_tensor ref_all_nodes = { test_model1.input.metadata.shape, test_model2.input.metadata.shape, @@ -7852,9 +7849,9 @@ def __call__( # type: ignore[override] model += MyModel() all_nodes = get_all_nodes(model) - assert get_origin(test_model1.input.metadata.edge_type) is Tensor - assert get_origin(test_model2.input.metadata.edge_type) is Tensor - assert get_origin(test_model3.input.metadata.edge_type) is Tensor + assert test_model1.input.metadata.is_tensor + assert test_model2.input.metadata.is_tensor + assert test_model3.input.metadata.is_tensor ref_all_nodes = { test_model1.input.metadata.shape, test_model2.input.metadata.shape, @@ -7893,7 +7890,7 @@ def test_uniadic_repr_count_2(): } model.set_shapes(shape_1) - assert get_origin(buff_model1.input.metadata.edge_type) is Tensor + assert buff_model1.input.metadata.is_tensor data_shape = buff_model1.input.metadata.shape assert data_shape is not None @@ -7928,7 +7925,7 @@ def test_uniadic_repr_count_3(): } ) - assert get_origin(buff_model1.input.metadata.edge_type) is Tensor + assert buff_model1.input.metadata.is_tensor data_shape = buff_model1.input.metadata.shape assert data_shape is not None @@ -8037,7 +8034,7 @@ def test_uniadic_repr_count_5(): model.set_shapes(shapes) - assert get_origin(buff_model1.input.metadata.edge_type) is Tensor + assert buff_model1.input.metadata.is_tensor data_shape = buff_model1.input.metadata.shape assert data_shape is not None @@ -8232,7 +8229,7 @@ def test_possible_uniadic_values_directed_8(): buff_model = Buffer() buff_model.set_shapes({"input": ["a", "b"]}) - assert get_origin(buff_model.input.metadata.edge_type) is Tensor + assert buff_model.input.metadata.is_tensor data_shape = buff_model.input.metadata.shape assert data_shape is not None @@ -8252,7 +8249,7 @@ def test_possible_uniadic_values_directed_9(): buff_model = Buffer() buff_model.set_shapes({"input": ["a", "b", "c", "d"]}) - assert get_origin(buff_model.input.metadata.edge_type) is Tensor + assert buff_model.input.metadata.is_tensor data_shape = buff_model.input.metadata.shape assert data_shape is not None @@ -9716,7 +9713,7 @@ def test_remove_variadic(): with pytest.raises(Exception) as err_info: # model.shape_map["output"].remove_variadic([Uniadic(5)]) data = model.conns.get_data("output") - assert get_origin(data.edge_type) is Tensor + assert data.is_tensor data_shape = data.shape assert data_shape is not None next(iter(data_shape.reprs)).remove_variadic([Uniadic(5)]) @@ -9733,7 +9730,7 @@ def test_bcast_left(): } model.set_shapes(shape_1) - assert get_origin(model.output.metadata.edge_type) is Tensor + assert model.output.metadata.is_tensor data_shape = model.output.metadata.shape assert data_shape is not None assert data_shape.get_shapes() == [2, "u1", "(V1, ...)"] @@ -9760,7 +9757,7 @@ def test_bcast_left_2(): } model.set_shapes(shape_1) - assert get_origin(model.output.metadata.edge_type) is Tensor + assert model.output.metadata.is_tensor data_shape = model.output.metadata.shape assert data_shape is not None @@ -9780,7 +9777,7 @@ def test_bcast_left_3(): } ) - assert get_origin(model.output.metadata.edge_type) is Tensor + assert model.output.metadata.is_tensor data_shape = model.output.metadata.shape assert data_shape is not None diff --git a/tests/scripts/test_utils.py b/tests/scripts/test_utils.py index 305c0051..bcb8111a 100644 --- a/tests/scripts/test_utils.py +++ b/tests/scripts/test_utils.py @@ -15,7 +15,7 @@ import inspect import re from collections.abc import Callable, Mapping, Sequence -from typing import Any, get_origin +from typing import Any import numpy as np @@ -295,10 +295,7 @@ def get_all_nodes(model: BaseModel): # recursively gets the all shape in the model (ShapeNode) all_data = get_all_data(model) node_set = { - data.shape - for data in all_data - if get_origin(data.edge_type) is Tensor - if data.shape is not None + data.shape for data in all_data if data.is_tensor if data.shape is not None } return node_set