diff --git a/benchmarks/speed_benchmarks/jax_fns.py b/benchmarks/speed_benchmarks/jax_fns.py index 27423f5d..170e1ccd 100644 --- a/benchmarks/speed_benchmarks/jax_fns.py +++ b/benchmarks/speed_benchmarks/jax_fns.py @@ -66,7 +66,7 @@ def setup(self): def __call__(self, inputs): x = inputs for lyr, actv in zip(self.layers, self.jax_activations, strict=False): - x = lyr(x) # type: ignore + x = lyr(x) x = actv(x) # type: ignore return x diff --git a/examples/model_api/linear_regression_jax_training.py b/examples/model_api/linear_regression_jax_training.py index 058f2205..0769b2b1 100644 --- a/examples/model_api/linear_regression_jax_training.py +++ b/examples/model_api/linear_regression_jax_training.py @@ -58,5 +58,5 @@ for i in range(num_epochs): outputs, gradients = pm.evaluate_all(params) updates, opt_state = optimizer.update(gradients, opt_state) - params = optax.apply_updates(params, updates) # type: ignore + params = optax.apply_updates(params, updates) print(f"Epoch: {i} / {num_epochs} -> ", outputs["final_cost"]) diff --git a/examples/model_api/variable_length_many_to_one_lstm.py b/examples/model_api/variable_length_many_to_one_lstm.py index 1671ed61..335e3ff3 100644 --- a/examples/model_api/variable_length_many_to_one_lstm.py +++ b/examples/model_api/variable_length_many_to_one_lstm.py @@ -67,8 +67,8 @@ target_end = int(input_end + target_lengths[idx]) # NOTE: Pylance sees int, int type arguments but throws an error. - single_input = backend.arange(start, input_end).reshape(-1, input_dim) # type: ignore - single_target = backend.arange(input_end, target_end).reshape(-1, output_dim) # type: ignore + single_input = backend.arange(start, input_end).reshape(-1, input_dim) + single_target = backend.arange(input_end, target_end).reshape(-1, output_dim) single_data = (single_input, single_target) train_data.append(single_data) @@ -150,7 +150,7 @@ # Prepare the test input data. test_input = backend.arange( starting_number, - starting_number + inference_max_input, # type: ignore + starting_number + inference_max_input, ).reshape(-1, input_dim) # Prepare the test data. @@ -172,7 +172,7 @@ # Prepare target values. test_target_values = backend.arange( - starting_number + inference_max_input, # type: ignore + starting_number + inference_max_input, starting_number + inference_max_input + inference_max_target_length, ) @@ -204,4 +204,4 @@ ) # Measure test error. -error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum() # type: ignore +error = backend.abs(unpacked_output_data.squeeze() - test_target_values).sum() diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py index c9569e4d..cba36ed7 100644 --- a/mithril/backends/backend.py +++ b/mithril/backends/backend.py @@ -57,14 +57,15 @@ def __init__(self, precision: int = 32, device: str = "cpu") -> None: # setattr(self, key, value) @property - def precision(self): + def precision(self) -> int: return self._precision + #!! @property - def device(self): + def device(self) -> Any: return self._device - def get_device(self): + def get_device(self) -> str: return self._device @property @@ -72,11 +73,11 @@ def inf(self) -> DataType | float: raise NotImplementedError("inf is not implemented") @property - def pi(self): + def pi(self) -> float: return math.pi @property - def e(self): + def e(self) -> float: return math.e @property @@ -104,7 +105,7 @@ def to_device( def block_until_ready(self, data: DataType) -> DataType | None: raise RuntimeError("Backend does not support block_until_ready method!") - def empty_cache(self): # noqa: B027 + def empty_cache(self) -> None: # noqa: B027 pass # print("Warning: empty_cache is not supported!") @@ -126,7 +127,7 @@ def cast(self, value: Any) -> Any: return value - def __del__(self): + def __del__(self) -> None: self.empty_cache() @overload diff --git a/mithril/backends/parallel.py b/mithril/backends/parallel.py index ea8451a5..1e454750 100644 --- a/mithril/backends/parallel.py +++ b/mithril/backends/parallel.py @@ -37,10 +37,10 @@ def run_callable(self, *primals: Any, fn_name: str) -> dict[str, Any]: @abstractmethod def parallelize( self, tensor: DataType, device_mesh: tuple[int, ...] | None = None - ) -> dict[str, Any]: + ) -> DataType: raise NotImplementedError() - def clean_up(self): + def clean_up(self) -> None: self.callables = dict() self.device_mesh = None self.n_devices = -1 diff --git a/mithril/backends/with_autograd/common_primitives.py b/mithril/backends/with_autograd/common_primitives.py index 50e3a064..98cd6000 100644 --- a/mithril/backends/with_autograd/common_primitives.py +++ b/mithril/backends/with_autograd/common_primitives.py @@ -63,80 +63,80 @@ ] -def greater(left: DataType, right: DataType): +def greater(left: DataType, right: DataType) -> DataType: return left > right -def greater_equal(left: DataType, right: DataType): +def greater_equal(left: DataType, right: DataType) -> DataType: return left >= right -def less(left: DataType, right: DataType): +def less(left: DataType, right: DataType) -> DataType: return left < right -def less_equal(left: DataType, right: DataType): +def less_equal(left: DataType, right: DataType) -> DataType: return left <= right -def equal(left: DataType, right: DataType): - return left == right +def equal(left: DataType, right: DataType) -> DataType: + return left == right # type: ignore -def not_equal(left: DataType, right: DataType): - return left != right +def not_equal(left: DataType, right: DataType) -> DataType: + return left != right # type: ignore -def logical_not(input: DataType): +def logical_not(input: DataType) -> DataType: return ~input -def logical_or(left: DataType, right: DataType): - return left | right +def logical_or(left: DataType, right: DataType) -> DataType: + return left | right # type: ignore -def logical_and(left: DataType, right: DataType): - return left & right +def logical_and(left: DataType, right: DataType) -> DataType: + return left & right # type: ignore -def matrix_multiplication(left: DataType, right: DataType): - return left @ right +def matrix_multiplication(left: DataType, right: DataType) -> DataType: + return left @ right # type: ignore -def add(left: DataType, right: DataType): - return left + right +def add(left: DataType, right: DataType) -> DataType: + return left + right # type: ignore -def subtract(left: DataType, right: DataType): - return left - right +def subtract(left: DataType, right: DataType) -> DataType: + return left - right # type: ignore -def multiplication(left: DataType, right: DataType): - return left * right +def multiplication(left: DataType, right: DataType) -> DataType: + return left * right # type: ignore -def divide(numerator: DataType, denominator: DataType): - return numerator / denominator +def divide(numerator: DataType, denominator: DataType) -> DataType: + return numerator / denominator # type: ignore -def floor_divide(numerator: DataType, denominator: DataType): - return numerator // denominator +def floor_divide(numerator: DataType, denominator: DataType) -> DataType: + return numerator // denominator # type: ignore -def shift_left(input: DataType, shift: DataType): - return input << shift +def shift_left(input: DataType, shift: DataType) -> DataType: + return input << shift # type: ignore -def shift_right(input: DataType, shift: DataType): - return input >> shift +def shift_right(input: DataType, shift: DataType) -> DataType: + return input >> shift # type: ignore -def power(base: DataType, exponent: DataType): - return base**exponent +def power(base: DataType, exponent: DataType) -> DataType: + return base**exponent # type: ignore -def squared_error(input: DataType, target: DataType): - return (input - target) ** 2 +def squared_error(input: DataType, target: DataType) -> DataType: + return (input - target) ** 2 # type: ignore def minus(input: DataType) -> DataType: @@ -148,18 +148,18 @@ def transpose( ) -> DataType: if not axes: return input.T - return input.transpose(*axes) + return input.transpose(*axes) # type: ignore -def swapaxes(input: DataType, axis1: int, axis2: int): +def swapaxes(input: DataType, axis1: int, axis2: int) -> DataType: return input.swapaxes(axis1, axis2) -def square(input: DataType): - return input * input +def square(input: DataType) -> DataType: + return input * input # type: ignore -def buffer(input: DataType): +def buffer(input: DataType) -> DataType: return input @@ -168,18 +168,20 @@ def permute_tensor(input: DataType, indices: DataType) -> DataType: def reshape(input: DataType, shape: tuple[int, ...]) -> DataType: - return input.reshape(shape) + return input.reshape(shape) # type: ignore def item(input: DataType) -> int | float | bool: return input.item() # type: ignore -def tensor_item(input: DataType, index: int | slice | tuple[int | slice, ...]): - return input[index] +def tensor_item( + input: DataType, index: int | slice | tuple[int | slice, ...] +) -> DataType: + return input[index] # type: ignore -def primitive_slice(start: int | None, stop: int | None, step: int | None): +def primitive_slice(start: int | None, stop: int | None, step: int | None) -> slice: return slice(start, stop, step) @@ -187,8 +189,8 @@ def length(input: DataType) -> int: return len(input) -def cartesian_diff(left: DataType, right: DataType): - return left[:, None, :] - right[None, :, :] +def cartesian_diff(left: DataType, right: DataType) -> DataType: + return left[:, None, :] - right[None, :, :] # type: ignore def primitive_embedding(input: DataType, weight: DataType) -> DataType: @@ -218,11 +220,11 @@ def union(*inputs: int | float | tuple[int | float, ...]) -> tuple[int | float, return result -def to_tuple(*args: tuple[int | float | bool, ...]): +def to_tuple(*args: int | float | bool) -> tuple[int | float | bool, ...]: return tuple(args) -def to_list(*args: tuple[int | float | bool, ...]): +def to_list(*args: int | float | bool) -> list[int | float | bool]: return list(args) @@ -291,7 +293,7 @@ def padding_converter_2d( def stride_converter( input: int | PaddingType | tuple[int, int] | None, kernel_size: int | tuple[int, int], -): +) -> int | tuple[int, int] | PaddingType: if input is None: return kernel_size else: @@ -303,7 +305,7 @@ def tuple_converter( | PaddingType | tuple[int, int] | tuple[tuple[int, int], tuple[int, int]], -): +) -> tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] | PaddingType: if isinstance(input, int): return (input, input) else: diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py index 5b009c77..ee1e63ad 100644 --- a/mithril/backends/with_autograd/jax_backend/backend.py +++ b/mithril/backends/with_autograd/jax_backend/backend.py @@ -75,14 +75,14 @@ def is_manualgrad(self) -> bool: return False @property - def inf(self): + def inf(self) -> float: return jax.numpy.inf @property - def nan(self): + def nan(self) -> float: return jax.numpy.nan - def get_backend_array_type(self): + def get_backend_array_type(self) -> type[jax.Array]: return jax.Array @property @@ -98,7 +98,7 @@ def DataType(self): # noqa: N802 return utils.ArrayType @staticmethod - def get_available_devices(): + def get_available_devices() -> list[str]: """Static method to get a list of available devices. Parameters @@ -112,7 +112,7 @@ def get_available_devices(): def register_primitive(fn: Callable[..., Any]) -> None: JaxBackend.registered_primitives[fn.__name__] = fn - def set_seed(self, seed: int): + def set_seed(self, seed: int) -> None: self.seed = seed self.prng_key = jax.random.PRNGKey(seed) @@ -145,7 +145,7 @@ def block_until_ready(self, data: jax.Array) -> jax.Array | None: def register_callable( self, fn: Callable[..., Any], fn_name: str, jit: bool = False - ): + ) -> None: assert ( self._parallel_manager is not None ), "Parallel manager is not initialized!" @@ -153,7 +153,7 @@ def register_callable( fn_name = str(id(self)) + fn_name return self._parallel_manager.register_callable(fn, fn_name, jit) - def _run_callable(self, *primals: jax.Array, fn_name: str): + def _run_callable(self, *primals: jax.Array, fn_name: str) -> Any: assert ( self._parallel_manager is not None ), "Parallel manager is not initialized!" @@ -161,7 +161,7 @@ def _run_callable(self, *primals: jax.Array, fn_name: str): fn_name = str(id(self)) + fn_name return self._parallel_manager.run_callable(*primals, fn_name=fn_name) - def _create_parallel(self, device_mesh: tuple[int, ...]): + def _create_parallel(self, device_mesh: tuple[int, ...]) -> None: self._parallel_manager = JaxParallel(math.prod(device_mesh), self._device) def array( @@ -538,7 +538,9 @@ def multinomial( return samples - def jit(self, *args: Any, **kwargs: Any): + def jit( # type: ignore[override] + self, *args: Any, **kwargs: Any + ) -> Callable[..., jax.Array | tuple[jax.Array, ...]] | dict[str, jax.Array]: return jax.jit(*args, **kwargs) def grad( diff --git a/mithril/backends/with_autograd/jax_backend/ops.py b/mithril/backends/with_autograd/jax_backend/ops.py index e51a718b..54246fbf 100644 --- a/mithril/backends/with_autograd/jax_backend/ops.py +++ b/mithril/backends/with_autograd/jax_backend/ops.py @@ -424,7 +424,7 @@ def conv2d( stride: tuple[int, int] = (1, 1), padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (1, 1), dilation: tuple[int, int] = (1, 1), -): +) -> jax.Array: _padding_normalized: tuple[tuple[int, int], tuple[int, int]] if is_tuple_int(padding): _padding_normalized = ((padding[0], padding[0]), (padding[1], padding[1])) @@ -451,7 +451,7 @@ def conv2d_bias( stride: tuple[int, int] = (1, 1), padding: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] = (1, 1), dilation: tuple[int, int] = (1, 1), -): +) -> jax.Array: return ( conv2d( input=input, @@ -559,7 +559,7 @@ def scaled_dot_product_attention( dropout_p: float = 0.0, is_causal: bool = False, scale: float | int | None = None, -): +) -> jax.Array: if dropout_p != 0.0: raise RuntimeError( "Currently Jax scaled_dot_product_attention only support dropout_p 0" @@ -632,7 +632,7 @@ def cross_entropy( f"Cross entropy got unexpected type for target '{target.dtype}'." ) - return ( + return ( # type: ignore -log(jnp.take_along_axis(input, target[:, None], axis=1)[:, 0]) * _weights[target] ) diff --git a/mithril/backends/with_autograd/jax_backend/parallel.py b/mithril/backends/with_autograd/jax_backend/parallel.py index 2ae2c9bd..901c182e 100644 --- a/mithril/backends/with_autograd/jax_backend/parallel.py +++ b/mithril/backends/with_autograd/jax_backend/parallel.py @@ -32,12 +32,12 @@ def __init__(self, n_devices: int, device: str) -> None: ) super().__init__(n_devices) - def run_callable(self, *primals: jax.Array, fn_name: str): + def run_callable(self, *primals: jax.Array, fn_name: str) -> Any: return self.callables[fn_name](*primals) def parallelize( self, tensor: jax.Array, device_mesh: tuple[int, ...] | None = None - ): + ) -> jax.Array: # Jax reuqires math.prod(device_mesh) == n_devices. To replicate a dimension # call 'replicate' method of Positional Sharding Object. Therefore, we need to # transform user provided device mesh to the one that satisfies the condition, @@ -67,11 +67,13 @@ def parallelize( return jax.device_put(tensor, sharding) - def register_callable(self, fn: Callable[..., Any], fn_name: str, jit: bool): + def register_callable( + self, fn: Callable[..., Any], fn_name: str, jit: bool + ) -> None: if jit: fn = jax.jit(fn) self.callables[fn_name] = fn - def clean_up(self): + def clean_up(self) -> None: self.callables = {} diff --git a/mithril/backends/with_autograd/jax_backend/utils.py b/mithril/backends/with_autograd/jax_backend/utils.py index 9c5139b6..367d2e90 100644 --- a/mithril/backends/with_autograd/jax_backend/utils.py +++ b/mithril/backends/with_autograd/jax_backend/utils.py @@ -78,7 +78,7 @@ def robust_power_helper( input1: jax.Array, input2: jax.Array, threshold: jax.Array ) -> jax.Array: def cond_fun(cond: jax.Array, input1: jax.Array, input2: jax.Array) -> jax.Array: - return jax.lax.cond( + return jax.lax.cond( # type: ignore cond, robust_power_under_threshold, robust_power_above_threshold, @@ -284,7 +284,7 @@ def polynomial_features_helper(x: jax.Array, y: jax.Array) -> jax.Array: ) -def get_available_devices(): +def get_available_devices() -> list[str]: backends: set[str] = set(jax._src.xla_bridge.backends()) - set(["interpreter"]) devices = [ f"{backend.replace('METAL','mps')}:{idx}" diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index a80f0c4d..dfa50e8b 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -53,7 +53,7 @@ def is_manualgrad(self) -> bool: return False @property - def inf(self): + def inf(self) -> float: return mx.inf @property diff --git a/mithril/backends/with_manualgrad/c_backend/backend.py b/mithril/backends/with_manualgrad/c_backend/backend.py index 369f1882..3710f7d1 100644 --- a/mithril/backends/with_manualgrad/c_backend/backend.py +++ b/mithril/backends/with_manualgrad/c_backend/backend.py @@ -29,7 +29,7 @@ class CBackend(Backend[PyArray]): type = "c" SRC_PATH = "mithril/backends/with_manualgrad/c_backend/src" - def __init__(self): + def __init__(self) -> None: self._precision = 32 self._device = "cpu" self.primitive_function_dict = {} diff --git a/mithril/backends/with_manualgrad/c_backend/src/array.pyi b/mithril/backends/with_manualgrad/c_backend/src/array.pyi index fd0a17a4..e36bb5c0 100644 --- a/mithril/backends/with_manualgrad/c_backend/src/array.pyi +++ b/mithril/backends/with_manualgrad/c_backend/src/array.pyi @@ -40,6 +40,7 @@ class PyArray: def __le__(self, other: PyArray) -> PyArray: ... def __and__(self, other: PyArray) -> PyArray: ... def __or__(self, other: PyArray) -> PyArray: ... + def __ror__(self, other: PyArray) -> PyArray: ... def __xor__(self, other: PyArray) -> PyArray: ... def __invert__(self) -> PyArray: ... def __matmul__(self, other: PyArray) -> PyArray: ... diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops.py b/mithril/backends/with_manualgrad/numpy_backend/ops.py index 22b4bb50..087a1669 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops.py @@ -20,8 +20,8 @@ from typing import Any import numpy as np -import scipy.linalg as slin # type: ignore[import-untyped] -from scipy.special import erf # type: ignore[import-untyped] +import scipy.linalg as slin +from scipy.special import erf from .... import core from ....utils.type_utils import is_tuple_int diff --git a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py index d4584aa5..a0e3bf31 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py +++ b/mithril/backends/with_manualgrad/numpy_backend/ops_grad.py @@ -17,8 +17,8 @@ from typing import Any import numpy as np -import scipy.linalg as slin # type: ignore[import-untyped] -from scipy.special import erf # type: ignore[import-untyped] +import scipy.linalg as slin +from scipy.special import erf from ....utils.type_utils import is_tuple_int from .ops import hinge_loss, sigmoid, softmax diff --git a/mithril/backends/with_manualgrad/numpy_backend/utils.py b/mithril/backends/with_manualgrad/numpy_backend/utils.py index 42d00762..2d8bccbf 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/utils.py +++ b/mithril/backends/with_manualgrad/numpy_backend/utils.py @@ -85,7 +85,7 @@ def write_into_cache[T: np.ndarray[Any, Any] | tuple[Any, ...] | int | float]( else: result = cache[key] # TODO: Resolve here - return result # type: ignore + return result def get_submatrices1d( diff --git a/mithril/framework/codegen/c_gen.py b/mithril/framework/codegen/c_gen.py index aba85b5e..470c65f5 100644 --- a/mithril/framework/codegen/c_gen.py +++ b/mithril/framework/codegen/c_gen.py @@ -250,7 +250,7 @@ def generate_evaluate_gradients(self) -> tuple[c_ast.FunctionDef, set[str]]: grad_inputs = [input_key + "_grad" for input_key in inputs] for idx in range(len(grad_inputs)): fn_inputs: list[str] = ( - [output_key + "_grad", c_ast.Constant(idx), output_key] + [output_key + "_grad", c_ast.Constant(idx).to_str(), output_key] + inputs + grad_inputs ) @@ -277,7 +277,7 @@ def generate_evaluate_gradients(self) -> tuple[c_ast.FunctionDef, set[str]]: return evaluate_grad_fn, used_keys - def _get_backend_path(self): + def _get_backend_path(self) -> str: backend_path = backend.__file__ return backend_path[: backend_path.rindex("/")] diff --git a/mithril/framework/codegen/python_gen.py b/mithril/framework/codegen/python_gen.py index 97750ffa..a85257f8 100644 --- a/mithril/framework/codegen/python_gen.py +++ b/mithril/framework/codegen/python_gen.py @@ -21,10 +21,10 @@ from typing import Any, Generic, Literal, Protocol, overload from ...backends.backend import ParallelBackend +from ...core import DataType from ...utils.func_utils import prepare_function_args from ..common import ( DataEvalType, - DataType, EvaluateAllType, EvaluateGradientsType, EvaluateType, @@ -104,7 +104,7 @@ def __init__(self, pm: PhysicalModel[DataType]) -> None: self.globals: list[ast.stmt] = [] self.functions: list[ast.stmt] = [] - def generate_code(self, file_path: str | None = None): + def generate_code(self, file_path: str | None = None) -> None: self.file_path = file_path self.imports += self.generate_imports() self.functions += self.generate_functions() @@ -119,10 +119,10 @@ def generate_code(self, file_path: str | None = None): if file_path is not None: self.write_code(file_path) - def generate_functions(self): + def generate_functions(self) -> list[ast.FunctionDef]: return [self.generate_evaluate()] - def write_code(self, file_path: str): + def write_code(self, file_path: str) -> None: if self.code is None: raise Exception( "Code is not generated yet! Please call generate_code() first." @@ -151,7 +151,7 @@ def exec_generated_code( if self.file_path is not None: module_name = splitext(basename(self.file_path))[0] - module_spec = importlib.util.spec_from_file_location( # type: ignore + module_spec = importlib.util.spec_from_file_location( module_name, self.file_path ) module = importlib.util.module_from_spec(module_spec) # type: ignore @@ -242,12 +242,12 @@ def post_process_fns( evaluate_all_fn is not None ), "Evaluate all function is not defined!" - grad_fn = self.pm.backend.jit(grad_fn) + grad_fn = self.pm.backend.jit(grad_fn) # type: ignore evaluate_all_fn = self.pm.backend.jit(evaluate_all_fn) - return eval_fn, grad_fn, evaluate_all_fn + return eval_fn, grad_fn, evaluate_all_fn # type: ignore - def import_backend(self): + def import_backend(self) -> ast.ImportFrom: backend = ast.ImportFrom( module="mithril", names=[ @@ -261,7 +261,7 @@ def import_backend(self): return backend - def generate_imports(self): + def generate_imports(self) -> list[ast.stmt]: imports: list[ast.stmt] = [] # Add import primitive functions imports.append( @@ -295,7 +295,9 @@ def generate_imports(self): return imports - def get_primitive_details(self, output_key: str): + def get_primitive_details( + self, output_key: str + ) -> tuple[PrimitiveModel, list[str], list[str]]: model = self.pm.flat_graph.get_model(output_key) global_input_keys = self.pm.flat_graph.get_source_keys(output_key) @@ -311,7 +313,7 @@ def call_primitive( g_input_keys: list[str], output_key: str, formula_key: str, - ): + ) -> tuple[ast.Assign, set[str]]: generated_fn, used_keys = self.create_primitive_call( fn, l_input_keys, g_input_keys ) @@ -321,9 +323,9 @@ def call_primitive( if formula_key in self.pm.backend.array_creation_funcs: self.add_partial_function(formula_key) - return ast.Assign(targets, generated_fn), used_keys | _used_keys # type: ignore + return ast.Assign(targets, generated_fn), used_keys | _used_keys - def generate_evaluate(self): + def generate_evaluate(self) -> ast.FunctionDef: input_body: list[ast.stmt] = [] function_body: list[ast.stmt] = [] return_values: list[ast.expr] = [] @@ -424,7 +426,9 @@ def generate_evaluate(self): return ast.fix_missing_locations(func_def) - def append_inputs(self, input_body: list[ast.stmt], key: str, dict_type: str): + def append_inputs( + self, input_body: list[ast.stmt], key: str, dict_type: str + ) -> None: # In manual_grad type backends, cache contains all the required # data (local variables and outputs) for the corresponding function. # So if the key is not directly an output of a function get it from @@ -538,7 +542,7 @@ def create_primitive_call_targets( return targets, {target_name} - def add_partial_function(self, formula_key: str): + def add_partial_function(self, formula_key: str) -> None: if formula_key in self.defined_partial_fns: return @@ -561,7 +565,7 @@ def create_gradient_fn( # raw_evaluate_grad_fn: ManualGradWrapperFn[DataType] | None, raw_evaluate_fn: RawEvaluateType[DataType], raw_evaluate_grad_fn: ManualGradWrapperFn[DataType] | None, - ): + ) -> tuple[ManualGradWrapperFn[DataType], RawEvaluateType[DataType]]: fn_all: EvaluateAllType[DataType] grad_fn: EvaluateGradientsType[DataType] if not self.pm.backend.is_manualgrad: @@ -572,20 +576,20 @@ def create_gradient_fn( include_output=False, ) # Fix fn_all for mlx support!! - fn_all = partial( + fn_all = partial( # type: ignore self.compute_gradients, raw_evaluate_fn=raw_evaluate_fn, cache=self.pm.data_store.data_values, include_output=True, ) - return grad_fn, fn_all + return grad_fn, fn_all # type: ignore else: assert raw_evaluate_grad_fn is not None, "Gradient function is not defined!" fn_all = partial(raw_evaluate_grad_fn, include_output=True) # type: ignore grad_fn = partial(raw_evaluate_grad_fn, include_output=False) # type: ignore - return grad_fn, fn_all + return grad_fn, fn_all # type: ignore @overload def compute_gradients( diff --git a/mithril/framework/codegen/torch_gen.py b/mithril/framework/codegen/torch_gen.py index 1b326de2..b3298a0a 100644 --- a/mithril/framework/codegen/torch_gen.py +++ b/mithril/framework/codegen/torch_gen.py @@ -40,7 +40,7 @@ def call_primitive( g_input_keys: list[str], output_key: str, formula_key: str, - ): + ) -> tuple[ast.Assign, set[str]]: generated_fn, used_keys = self.create_primitive_call( fn, l_input_keys, g_input_keys ) diff --git a/mithril/framework/common.py b/mithril/framework/common.py index 074f5a43..4b848b51 100644 --- a/mithril/framework/common.py +++ b/mithril/framework/common.py @@ -68,7 +68,7 @@ "Tensor", "Scalar", "ShapesType", - "_ShapesType", + "ShapeResultType", "get_summary_shapes", "get_summary_types", "ConstraintSolver", @@ -195,6 +195,7 @@ class KeyType(Enum): | tuple[Any, ...] | list[Any] | dict[Any, Any] + | Mapping[Any, Any] | Constant | slice | PaddingType @@ -212,6 +213,7 @@ class KeyType(Enum): | tuple[Any, ...] | list[Any] | dict[Any, Any] + | Mapping[Any, Any] | bool | None | EllipsisType @@ -291,7 +293,7 @@ def __call__( def update_equivalence_table( item1: ItemType, item2: ItemType, lookup_table: dict[ItemType, set[ItemType]] -): +) -> None: item_set1 = lookup_table.get(item1) item_set2 = lookup_table.get(item2) if item_set1 is None and item_set2 is None: @@ -321,7 +323,7 @@ class ConstraintSolver: default_factory=lambda: {} ) - def __call__(self, updates: Updates): + def __call__(self, updates: Updates) -> None: self.update_shapes(updates) solved_constrs: set[Constraint] = set() for constr_type in UpdateType: @@ -332,7 +334,7 @@ def _solver_loop( constraint_type: UpdateType, updates: Updates, solved_constraints: set[Constraint], - ): + ) -> None: constraints = updates.constraints[constraint_type] while constraints: constr = constraints.pop() @@ -370,14 +372,14 @@ def _solver_loop( constraints.discard(constr) @staticmethod - def _combine_nodes(updates: Updates): + def _combine_nodes(updates: Updates) -> None: # Check if any node could be reduced after variadic updates add into # node_updates field. while updates.node_updates: node = updates.node_updates.pop() updates |= node.combine() - def _reduce_uniadic_referees(self, updates: Updates): + def _reduce_uniadic_referees(self, updates: Updates) -> None: while updates.uniadic_updates: uni = updates.uniadic_updates.pop() uni_val = uni.value @@ -445,7 +447,7 @@ def _add_sublists( return updates - def clear(self): + def clear(self) -> None: self.symbol_store = {} self.constraint_map = {} self.empty_node = None @@ -479,7 +481,7 @@ def _delete_node(remaining: ShapeNode, deleted: ShapeNode) -> Updates: deleted.reprs = [] return updates - def update_shapes(self, updates: Updates): + def update_shapes(self, updates: Updates) -> None: deletion_nodes: dict[ShapeNode, set[ShapeNode]] = {} # Note that update can be tuple also. First element of update # is always Tensor | Scalar. So this is the case, get first element @@ -732,7 +734,7 @@ def match(self, other: BaseData[T]) -> Updates: self.finalize_match(other) return updates - def set_value(self, value: AllValueType) -> Updates: # type: ignore[override] + def set_value(self, value: AllValueType) -> Updates: updates = Updates() if self.value is not TBD and self.value != value: raise ValueError( @@ -1203,7 +1205,7 @@ def __hash__(self) -> int: | Mapping[str, ShapeTemplateType] | Mapping[Connection, ShapeTemplateType] ) -_ShapesType = Mapping[str, ShapeTemplateType | list[ShapeTemplateType] | None] +ShapeResultType = Mapping[str, ShapeTemplateType | list[ShapeTemplateType] | None] @dataclass @@ -1398,7 +1400,7 @@ def get_shape_node(self, key: str) -> ShapeNode: return data.shape def set_value(self, con: ConnectionData, value: MainValueType): - self.get_data(con.key).set_value(value) # type: ignore + self.get_data(con.key).set_value(value) def extract_metadata(self, key: str | Connection) -> IOHyperEdge: if isinstance(key, Connection): @@ -2744,7 +2746,7 @@ def __call__(self, keys: list[Scalar | Tensor]) -> ConstrainResultType: def add_post_process(self, fn: ConstraintFunctionType): self.post_processes.add(fn) - def create_post_constraints(self): + def create_post_constraints(self) -> set[Constraint]: constraints: set[Constraint] = set() for fn in self.post_processes: constraints.add(Constraint(fn, self.type)) @@ -2902,7 +2904,7 @@ def compile( # adjust the table accordingly self._adjust_table() # calculate total table width - table_width = reduce( # type: ignore + table_width = reduce( partial(add_lengths, const=(len(row) for row in row_sep)), self.each_row_width, ) @@ -2953,7 +2955,7 @@ def compile( self.header_str = header_str self.cell_str = cell_str - def display(self): + def display(self) -> None: """Prints the table""" print(self.header_str) print(self.cell_str) @@ -2964,7 +2966,7 @@ def construct_subtable_row( arg_max_lengths: list[int], adjustments: list[str], *args: list[list[str]], - ): + ) -> list[str]: # Constructs subtable with given args subtable_list: list[str] = [] elems: tuple[list[str], ...] @@ -3404,7 +3406,7 @@ def get_summary( ) cell = input_table + sep + output_table - table.add_row([model_name, cell]) + table.add_row([[model_name], cell]) subheader_adjustments = ["left", "left", "left", "left", "left"] subheader_adjustments = [ @@ -3425,10 +3427,10 @@ def get_summary( def get_summary_shapes( - model_shapes: dict[str, _ShapesType], + model_shapes: dict[str, ShapeResultType], conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]], ): - shape_info: dict[str, tuple[_ShapesType, _ShapesType]] = {} + shape_info: dict[str, tuple[ShapeResultType, ShapeResultType]] = {} for model_name in conn_info: shape = model_shapes[model_name] input_conns, output_conns = conn_info[model_name] diff --git a/mithril/framework/constraints.py b/mithril/framework/constraints.py index 6c71ad05..6f46b8d3 100644 --- a/mithril/framework/constraints.py +++ b/mithril/framework/constraints.py @@ -105,7 +105,7 @@ # Below functions are used in various constraints. def prod_fn(a: int | Uniadic, b: int | Uniadic) -> int: - return (a if isinstance(a, int) else a.value) * ( # type: ignore + return (a if isinstance(a, int) else a.value) * ( b if isinstance(b, int) else b.value ) @@ -495,7 +495,7 @@ def scalar_item_reduce_input_type( inner_type.append(output_type) if arg is tuple: inner_type.append(...) - possible_types.append(arg[*inner_type]) # type: ignore + possible_types.append(arg[*inner_type]) return create_union_type(*possible_types) elif isinstance(input_type, GenericAlias): input_origin = input_type.__origin__ diff --git a/mithril/framework/logical/base.py b/mithril/framework/logical/base.py index c518ee6d..cc499728 100644 --- a/mithril/framework/logical/base.py +++ b/mithril/framework/logical/base.py @@ -15,7 +15,7 @@ from __future__ import annotations import abc -from collections.abc import Mapping +from collections.abc import KeysView, Mapping from dataclasses import dataclass from itertools import chain from types import UnionType @@ -33,7 +33,6 @@ ConstraintSolver, IOHyperEdge, MainValueType, - NestedListType, NotAvailable, Scalar, ShapeNode, @@ -49,6 +48,7 @@ get_shapes, ) from ..constraints import post_process_map, type_constraints +from ..utils import NestedListType __all__ = ["BaseModel", "ExtendInfo"] @@ -58,7 +58,7 @@ class ExtendInfo: _model: BaseModel _connections: dict[str, ConnectionType] - def __post_init__(self): + def __post_init__(self) -> None: external_keys = set(self._model.external_keys) if self._model.canonical_input is not NOT_AVAILABLE: external_keys.add(self._model.canonical_input.key) @@ -70,11 +70,11 @@ def __post_init__(self): raise KeyError(f"Key '{key}' is not a valid key for the model!") @property - def model(self): + def model(self) -> BaseModel: return self._model @property - def connections(self): + def connections(self) -> dict[str, ConnectionType]: return self._connections @@ -134,23 +134,25 @@ def jittable(self) -> bool: return self._jittable @property - def shapes(self): + def shapes( + self, + ) -> Mapping[str, ShapeTemplateType | list[ShapeTemplateType] | None]: return self.get_shapes() @property - def external_keys(self): + def external_keys(self) -> KeysView[str]: return self.conns.io_keys @property - def input_keys(self): + def input_keys(self) -> KeysView[str]: return self.conns.input_keys @property - def _all_keys(self): + def _all_keys(self) -> KeysView[str]: return self.conns.all.keys() @property - def output_keys(self): + def output_keys(self) -> list[str]: output_keys = list(self.conns.output_keys) if ( self.canonical_output is not NOT_AVAILABLE @@ -159,12 +161,12 @@ def output_keys(self): output_keys.append("#canonical_output") return output_keys - def check_extendability(self): + def check_extendability(self) -> None: # Check possible errors before the extension. if self.parent is not None: raise AttributeError("Submodel of a model could not be extended!") - def _get_outermost_parent(self): + def _get_outermost_parent(self) -> BaseModel: model = self while model.parent is not None: model = model.parent @@ -178,7 +180,7 @@ def generate_keys( ) -> dict[str, str]: return {} - def __setattr__(self, name: str, value: Any): + def __setattr__(self, name: str, value: Any) -> None: # You need to be careful here to avoid infinite recursion if ( getattr(self, "frozen_attributes", None) is not None @@ -282,7 +284,7 @@ def _set_value(self, key: ConnectionData, value: MainValueType | str) -> Updates if key.key not in self.conns.input_keys: raise ValueError("Values of internal and output keys cannot be set.") # Data is scalar, set the value directly. - return key.metadata.data.set_value(value) # type: ignore + return key.metadata.data.set_value(value) def set_shapes( self, config: ShapesType | None = None, **kwargs: ShapeTemplateType @@ -351,7 +353,7 @@ def set_types( | Mapping[str, type | UnionType | NestedListType] | None = None, **kwargs: type | UnionType | NestedListType, - ): + ) -> None: """ Set types of any connection in the Model @@ -404,7 +406,7 @@ def _set_constraint( keys: list[str], post_processes: set[ConstraintFunctionType] | None = None, type: UpdateType | None = None, - ): + ) -> None: all_conns = self.conns.all hyper_edges = [all_conns[key].metadata for key in keys] if type is None: @@ -453,7 +455,7 @@ def canonical_output(self) -> Connection | NotAvailable: else: return self._canonical_output.conn - def set_canonical_input(self, given_conn: str | Connection): + def set_canonical_input(self, given_conn: str | Connection) -> None: if isinstance(given_conn, str): conn = self.conns.all.get(given_conn) if conn is None: @@ -471,7 +473,7 @@ def set_canonical_input(self, given_conn: str | Connection): self._canonical_input = conn - def set_canonical_output(self, given_conn: str | Connection): + def set_canonical_output(self, given_conn: str | Connection) -> None: if isinstance(given_conn, str): conn = self.conns.all.get(given_conn) if conn is None: @@ -521,7 +523,7 @@ def _match_hyper_edges(self, left: IOHyperEdge, right: IOHyperEdge) -> Updates: updates = left.data.match(right.data) # type: ignore return updates - def get_models_in_topological_order(self): + def get_models_in_topological_order(self) -> list[BaseModel]: dependency_map = self.dependency_map.local_output_dependency_map graph = { info[0]: OrderedSet( @@ -542,7 +544,7 @@ def _reverse_dfs( graph: dict[BaseModel, OrderedSet[BaseModel]], top_order: list[BaseModel], visited: set[BaseModel], - ): + ) -> None: visited.add(node) for m in graph[node]: if m not in visited: @@ -586,7 +588,9 @@ def __init__(self, connections: Connections) -> None: ] = {} # Add new model to dependency map, model_dag is created in extend - def add_model_dag(self, model: BaseModel, model_dag: dict[str, ConnectionData]): + def add_model_dag( + self, model: BaseModel, model_dag: dict[str, ConnectionData] + ) -> None: updated_conns: OrderedSet[ConnectionData] = OrderedSet() for local_key, conn in model_dag.items(): if local_key in model.conns.input_keys: @@ -629,7 +633,7 @@ def add_model_dag(self, model: BaseModel, model_dag: dict[str, ConnectionData]): # Caches extended connections to avoid traverse def cache_internal_references( self, output_conn: ConnectionData, dependent_conns: OrderedSet[ConnectionData] - ): + ) -> None: # Be sure all input and output keys has cache entry for conn in self.conns.input_connections: self._global_input_dependency_map_cache.setdefault(conn, OrderedSet()) @@ -696,13 +700,13 @@ def cache_internal_references( ) # Caches given input connection for later usage - def cache_conn_input_dependency(self, conn: ConnectionData): + def cache_conn_input_dependency(self, conn: ConnectionData) -> None: if conn not in self._global_input_dependency_map_cache: dependents = self.get_output_key_dependency(conn.key) self._global_input_dependency_map_cache[conn] = dependents # Caches given output connection for later usage - def cache_conn_output_dependency(self, conn: ConnectionData): + def cache_conn_output_dependency(self, conn: ConnectionData) -> None: if conn not in self._global_output_dependency_map_cache: dependents = self.get_input_key_dependency(conn.key) self._global_output_dependency_map_cache[conn] = dependents @@ -740,7 +744,7 @@ def get_dependent_output_conns(self, key: str) -> OrderedSet[ConnectionData]: return dependent_conns # Update dependecy map - def update_all_keys(self): + def update_all_keys(self) -> None: # This method is used in freeze, because in freeze dependencies changed # without updating dependency map. self.update_globals( @@ -765,7 +769,9 @@ def _get_from_input_cache(self, conn: ConnectionData) -> OrderedSet[ConnectionDa # Get dependent input connections if given output connection is cached # else returns None - def _get_from_output_cache(self, conn: ConnectionData): + def _get_from_output_cache( + self, conn: ConnectionData + ) -> OrderedSet[ConnectionData]: dependent_conns = self._global_output_dependency_map_cache.get( conn, OrderedSet() ) @@ -779,7 +785,7 @@ def _get_from_output_cache(self, conn: ConnectionData): return dependent_conns # Update global dependency maps wrt given connections - def update_globals(self, updated_conns: OrderedSet[ConnectionData]): + def update_globals(self, updated_conns: OrderedSet[ConnectionData]) -> None: for input_conn in self.conns.input_connections: self._global_input_dependency_map.setdefault(input_conn, OrderedSet()) @@ -822,7 +828,7 @@ def update_globals(self, updated_conns: OrderedSet[ConnectionData]): updated_conns |= OrderedSet(dependent_conns) # Retrieve dependent output connection keys given input key by traversing the graph. - def get_input_key_dependency(self, key: str): + def get_input_key_dependency(self, key: str) -> OrderedSet[ConnectionData]: if (given_conn := self.conns.get_connection(key)) is None: raise KeyError("Given key does not belong to the Model!") # If there already exists any input keys, add them. @@ -864,7 +870,7 @@ def get_input_key_dependency(self, key: str): return specs # Retrieve dependent input connection keys given output key by traversing the graph. - def get_output_key_dependency(self, key: str): + def get_output_key_dependency(self, key: str) -> OrderedSet[ConnectionData]: if (given_conn := self.conns.get_connection(key)) is None: raise KeyError("Given key does not belong to the Model!") @@ -911,7 +917,9 @@ def look_for_cyclic_connection( ) return False - def merge_global_connections(self, conn1: ConnectionData, conn2: ConnectionData): + def merge_global_connections( + self, conn1: ConnectionData, conn2: ConnectionData + ) -> None: conn1_global_out_dependency = self._global_output_dependency_map.get(conn1) conn2_global_out_dependency = self._global_output_dependency_map.pop( conn2, None @@ -954,7 +962,7 @@ def merge_global_connections(self, conn1: ConnectionData, conn2: ConnectionData) self._global_output_dependency_map[dependent_conn].remove(conn2) self._global_output_dependency_map[dependent_conn].add(conn1) - def merge_global_caches(self, conn1: ConnectionData, conn2: ConnectionData): + def merge_global_caches(self, conn1: ConnectionData, conn2: ConnectionData) -> None: conn1_global_out_cache = self._global_output_dependency_map_cache.get(conn1) conn2_global_out_cache = self._global_output_dependency_map_cache.pop( conn2, None diff --git a/mithril/framework/logical/model.py b/mithril/framework/logical/model.py index 82ffd489..eff3dc65 100644 --- a/mithril/framework/logical/model.py +++ b/mithril/framework/logical/model.py @@ -307,11 +307,11 @@ def _convert_to_iokey( assert local_connection is not None, "Connection is not found!" match connection: case NullConnection(): - connection = IOKey() + _connection = IOKey() case str(): - connection = IOKey(name=connection) + _connection = IOKey(name=connection) case Connection(): - connection = IOKey(connections={connection}) + _connection = IOKey(connections={connection}) case ExtendTemplate(): # Unroll ExtendTemplate template_conn = model.conns.get_connection(key) @@ -319,7 +319,7 @@ def _convert_to_iokey( con_data = self._unroll_template( connection, type(template_conn.metadata.data) ) - connection = IOKey(connections={con_data.conn}, expose=False) + _connection = IOKey(connections={con_data.conn}, expose=False) case _ if isinstance(connection, MainValueInstance): # find_dominant_type returns the dominant type in a container. # If a container has a value of type Connection or ExtendTemplate @@ -338,10 +338,10 @@ def _convert_to_iokey( result = conv_model.conns.get_connection("output") assert result is not None - connection = IOKey(connections={result.conn}, expose=None) + _connection = IOKey(connections={result.conn}, expose=None) else: assert isinstance(connection, MainValueInstance) - connection = IOKey(value=connection) + _connection = IOKey(value=connection) case IOKey(): expose = connection.expose name = connection.name @@ -354,7 +354,7 @@ def _convert_to_iokey( and connection.connections == set() ): expose = True - connection = IOKey( + _connection = IOKey( name=name, expose=expose, connections=connection.connections, @@ -370,7 +370,7 @@ def _convert_to_iokey( "provide connection/key explicitly, or set canonical connections." ) - return connection + return _connection def update_key_name(self, connection: ConnectionData, key: str) -> None: for key_type in KeyType: diff --git a/mithril/framework/logical/primitive.py b/mithril/framework/logical/primitive.py index d1f44a63..cff62418 100644 --- a/mithril/framework/logical/primitive.py +++ b/mithril/framework/logical/primitive.py @@ -61,7 +61,7 @@ def __init__( super().__init__(name=name) - self._random_keys: set[str] = set() + self.random_keys: set[str] = set() # Get shape_templates of TensorTypes and create corresponding shapes. shape_templates = { key: value.shape @@ -99,7 +99,7 @@ def __init__( assert isinstance(value, BaseKey) _value = Tensor( shape=shapes[key].node, - possible_types=get_mytensor_subtype(value_type), # type: ignore + possible_types=get_mytensor_subtype(value_type), value=value.value, # type: ignore interval=value.interval, ) @@ -109,7 +109,7 @@ def __init__( else: _value = Scalar( possible_types=value_type, # type: ignore - value=value.value, # type: ignore + value=value.value, ) conn_data = self.create_connection(IOHyperEdge(_value), key) diff --git a/mithril/framework/physical/data_store.py b/mithril/framework/physical/data_store.py index b73bffcd..7b53a9cd 100644 --- a/mithril/framework/physical/data_store.py +++ b/mithril/framework/physical/data_store.py @@ -17,14 +17,13 @@ from typing import Any, Generic, TypeGuard from ...backends.backend import Backend -from ...core import DataType, data_types, epsilon_table +from ...core import Constant, DataType, data_types, epsilon_table from ...utils.func_utils import is_make_array_required, prepare_function_args from ...utils.utils import BiMap from ..common import ( TBD, AllValueType, Connection, - Constant, ConstraintSolver, DataEvalType, MainValueInstance, @@ -55,7 +54,7 @@ def __init__( self.graph: FlatGraph[DataType] = graph self.backend: Backend[DataType] = backend self.inference = inference - self._intermediate_non_differentiables: BiMap[str, Tensor | Scalar] = BiMap() + self.intermediate_non_differentiables: BiMap[str, Tensor | Scalar] = BiMap() # type: ignore self._runtime_static_keys: set[str] = set() self._unused_keys: set[str] = set() # Final tensor values of data store. @@ -64,11 +63,11 @@ def __init__( self._random_seeds: dict[str, int] = dict() @property - def all_data(self): + def all_data(self) -> dict[str, Tensor | Scalar]: return self._all_data @property - def cached_data(self): + def cached_data(self) -> DataEvalType[DataType]: return self.data_values @property @@ -91,22 +90,20 @@ def unused_keys(self) -> set[str]: def is_scalar_type(t: Any) -> TypeGuard[MainValueType]: return not isinstance(t, tuple(data_types)) - def remove_keys_from_store(self, keys: set[str]): + def remove_keys_from_store(self, keys: set[str]) -> None: keys -= set(self.graph.output_keys) for key in keys: self.remove_key_from_store(key, label_as_unused=False, hard_remove=True) def remove_key_from_store( self, key: str, label_as_unused: bool = True, hard_remove: bool = False - ): - # Remove key from all attributes. + ) -> None: if key in self.data_values: self.data_values.pop(key) # type: ignore self._runtime_static_keys.discard(key) - - if key in self._intermediate_non_differentiables: - self._intermediate_non_differentiables.pop(key) + if key in self.intermediate_non_differentiables: + self.intermediate_non_differentiables.pop(key) if key in self._random_seeds: self._random_seeds.pop(key) @@ -121,7 +118,7 @@ def remove_key_from_store( self._all_data.pop(key) self._clear_constraints(key) - def _clear_constraints(self, key: str): + def _clear_constraints(self, key: str) -> None: if key not in self._all_data: return @@ -132,16 +129,16 @@ def _clear_constraints(self, key: str): self._all_data[source_key].shape_constraints -= shape_constraints self._all_data[source_key].type_constraints -= type_constraints - def _update_cached_data(self, updated_data: Updates) -> set[str]: + def update_cached_data(self, updated_data: Updates) -> set[str]: # If any data value is found by shape inference algorithms # transfer this data in cached_data. transferred_keys: set[str] = set() updated_inter_data = ( updated_data.value_updates - & self._intermediate_non_differentiables.inverse.keys() + & self.intermediate_non_differentiables.inverse.keys() ) for data in updated_inter_data: - key = self._intermediate_non_differentiables.inverse[data] + key = self.intermediate_non_differentiables.inverse[data] if key in self.data_values or data.value is not TBD: if key in self.data_values: raise KeyError( @@ -152,10 +149,10 @@ def _update_cached_data(self, updated_data: Updates) -> set[str]: self._set_data_value(key, data) transferred_keys.add(key) for key in transferred_keys: - self._intermediate_non_differentiables.pop(key) + self.intermediate_non_differentiables.pop(key) return transferred_keys - def _set_data_value(self, key: str, data: Tensor | Scalar): + def _set_data_value(self, key: str, data: Tensor | Scalar) -> None: value: DataType | AllValueType = data.value assert not isinstance(value, ToBeDetermined) if isinstance(data, Tensor): @@ -166,7 +163,7 @@ def _set_data_value(self, key: str, data: Tensor | Scalar): self.data_values[key] = value # type: ignore - def _infer_unused_keys(self, key: str): + def infer_unused_keys(self, key: str) -> None: # Infers unused keys when "key" is set as static. output_keys = self.graph.output_keys queue = set(self.graph.get_source_keys(key, True)) @@ -206,11 +203,11 @@ def set_shapes( updates |= data.shape.set_values(value) self.constraint_solver(updates) # Some intermediate values may be calculated, update cached data. - new_statics = self._update_cached_data(updates) + new_statics = self.update_cached_data(updates) for key in new_statics: - self._infer_unused_keys(key) + self.infer_unused_keys(key) - def update_data(self, data: dict[str, Tensor | Scalar]): + def update_data(self, data: dict[str, Tensor | Scalar]) -> None: if data.keys() & self._all_data.keys(): raise Exception("Some keys are already in data store!") self._all_data |= data @@ -236,7 +233,7 @@ def update_data(self, data: dict[str, Tensor | Scalar]): if value.value is not TBD: self._set_data_value(key, value) else: - self._intermediate_non_differentiables[key] = value + self.intermediate_non_differentiables[key] = value def set_static_keys( self, @@ -283,11 +280,11 @@ def add_static_data( if isinstance(data, Tensor): assert not isinstance(value, MainValueInstance) # Find shape of tensor and set. - shape = list(value.shape) + shape = list(value.shape) # type: ignore updates |= data.shape.set_values(shape) # Find type of tensor and set. val_type: type[bool] | type[int] | type[float] - data_dtype = str(value.dtype) + data_dtype = str(value.dtype) # type: ignore # Check value type is OK, and update type accordinly. if "bool" in data_dtype: val_type = bool @@ -301,7 +298,7 @@ def add_static_data( "Only float, int or bool types are accepted." ) updates |= data.set_type(val_type) - elif isinstance(data, Scalar) and self.is_scalar_type(value): + elif self.is_scalar_type(value): updates |= data.set_value(value) else: raise ValueError( @@ -309,18 +306,18 @@ def add_static_data( f"the type of data: {type(data)}!" ) self.data_values[key] = value # type: ignore - self._intermediate_non_differentiables.pop(key, None) + self.intermediate_non_differentiables.pop(key, None) if ( - key not in self._intermediate_non_differentiables + key not in self.intermediate_non_differentiables and key in self.runtime_static_keys ): self._runtime_static_keys.remove(key) # Finally update cached_data, infer unused keys and # return newly added static keys. self.constraint_solver(updates) - statics = self._update_cached_data(updates) | updated_keys + statics = self.update_cached_data(updates) | updated_keys for static in statics: - self._infer_unused_keys(static) + self.infer_unused_keys(static) return statics, updates @@ -408,20 +405,17 @@ def infer_static_keys(self) -> Updates: updates |= _updates return updates - def set_random_seed_keys(self, seed_keys: set[str]): + def set_random_seed_keys(self, seed_keys: set[str]) -> None: for key in seed_keys: if self.all_data[key].value == TBD: self._random_seeds[key] = 0 else: - self._random_seeds[key] = self.all_data[key].value + value = self.all_data[key].value + assert isinstance(value, int) + self._random_seeds[key] = value - def set_random_seed_values(self, **seed_mapping: int): + def set_random_seed_values(self, **seed_mapping: int) -> None: for key, value in seed_mapping.items(): if key not in self._random_seeds: raise KeyError(f"'{key}' key is not a random seed key!") - if not isinstance(value, int): - raise TypeError( - f"Random seed value for '{key}' key must be an integer!" - ) - self._random_seeds[key] = value diff --git a/mithril/framework/physical/flat_graph.py b/mithril/framework/physical/flat_graph.py index 469fc9bb..88c3783d 100644 --- a/mithril/framework/physical/flat_graph.py +++ b/mithril/framework/physical/flat_graph.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Callable, Mapping +from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass from ...core import DataType, GenericDataType @@ -52,7 +52,7 @@ class Connection: target_keys: list[str] connections: set[Connection] - def __hash__(self): + def __hash__(self) -> int: return hash(id(self)) @@ -95,7 +95,7 @@ def __init__(self, input_keys: set[str], output_keys: set[str]) -> None: self._topological_order: list[str] = [] self._input_keys = input_keys - self._random_keys: set[str] = set() + self.random_keys: set[str] = set() self.output_dict: dict[str, str] = {key: key for key in output_keys} self._temp_connection_info: dict[str, str] = {} @@ -104,7 +104,7 @@ def __init__(self, input_keys: set[str], output_keys: set[str]) -> None: self.value_table: dict[str, DataType | ValueType] = {} @property - def hanging_keys(self): + def hanging_keys(self) -> set[str]: hanging_keys = (self.all_target_keys - self.all_source_keys) | set( self.connections.keys() ) - self.all_target_keys - self.all_source_keys @@ -112,30 +112,30 @@ def hanging_keys(self): return hanging_keys - set(self.output_dict.values()) @property - def input_keys(self): + def input_keys(self) -> set[str]: return set(self._input_keys) @property - def output_keys(self): + def output_keys(self) -> set[str]: return set(self.output_dict.keys()) @property - def all_keys(self): + def all_keys(self) -> set[str]: return ( set(self.connections.keys()) | set(self.output_dict.keys()) | set(self.output_dict.values()) ) - def add_value(self, model: PrimitiveModel, keys: dict[str, str]): + def add_value(self, model: PrimitiveModel, keys: dict[str, str]) -> None: output_key = keys[PrimitiveModel.output_key] keys = { key: self._temp_connection_info.get(value, value) for key, value in keys.items() } - if model._random_keys: - self._random_keys |= {keys[key] for key in model._random_keys} + if model.random_keys: + self.random_keys |= {keys[key] for key in model.random_keys} # Buffer primitives are not added to the graph if isinstance(model, Buffer): @@ -182,7 +182,7 @@ def add_value(self, model: PrimitiveModel, keys: dict[str, str]): self._update_all_source_keys() self._update_all_target_keys() - def collapse_model_keys(self, output_key: str, new_reference_key: str): + def collapse_model_keys(self, output_key: str, new_reference_key: str) -> None: # If a model removed, the models that uses the output of the removed model # should be updated with the new reference key. for key, value in self._temp_connection_info.items(): @@ -201,7 +201,7 @@ def update_output_keys(self, output_key: str, new_reference_key: str) -> bool: return True @property - def topological_order(self): + def topological_order(self) -> list[str]: return self._topological_order @property @@ -212,15 +212,13 @@ def all_target_keys(self) -> set[str]: def all_source_keys(self) -> set[str]: return self._all_source_keys - def _update_topological_order(self): + def _update_topological_order(self) -> None: self._topological_order = [ node.connections[PrimitiveModel.output_key].key for node in self.nodes.values() - if node.model is not None - or node.connections[PrimitiveModel.output_key].key in self.output_keys ] - def _update_all_source_keys(self): + def _update_all_source_keys(self) -> None: self._all_source_keys = { conn.key for item in self.nodes.values() @@ -228,7 +226,7 @@ def _update_all_source_keys(self): if key != "output" } - def _update_all_target_keys(self): + def _update_all_target_keys(self) -> None: self._all_target_keys = { conn.key for item in self.nodes.values() @@ -236,7 +234,7 @@ def _update_all_target_keys(self): if key == "output" } - def _update_connection_keys(self, connection: Connection): + def _update_connection_keys(self, connection: Connection) -> None: source_keys: list[str] = [] target_keys: list[str] = [] @@ -247,7 +245,7 @@ def _update_connection_keys(self, connection: Connection): key = conn.key source_keys.append(key) - def get_target_keys(connection: Connection): + def get_target_keys(connection: Connection) -> list[str]: target_keys: list[str] = [] for conn in connection.connections: target_keys.append(conn.key) @@ -275,27 +273,27 @@ def get_target_keys(connection: Connection): connection.target_keys = list(target_keys) connection.source_keys = list(source_keys) - def get_model(self, key) -> PrimitiveModel: + def get_model(self, key: str) -> PrimitiveModel: conn = self.connections.get(key, None) if conn is None or conn.node is None: raise ValueError(f"Model not found for key: {key}") return conn.node.model - def get_model_out_key(self, model: PrimitiveModel): + def get_model_out_key(self, model: PrimitiveModel) -> str | None: node = self.nodes.get(model, None) if node is None: return None return node.connections[PrimitiveModel.output_key].key - def get_model_outer_key(self, model: PrimitiveModel, inner_key: str): + def get_model_outer_key(self, model: PrimitiveModel, inner_key: str) -> str: return self.nodes[model].connections[inner_key].key - def get_model_connections(self, model: PrimitiveModel): + def get_model_connections(self, model: PrimitiveModel): # type: ignore return self.nodes[model].connections.values() - def get_connection(self, key: str): - return self.connections.get(key, None) + def get_connection(self, key: str) -> Connection | None: + return self.connections.get(key) def get_source_keys(self, key: str, include_outputs: bool = False) -> list[str]: source_keys: list[str] = [] @@ -346,10 +344,7 @@ def _is_duplicate( node: Node, data: dict[str, Tensor | Scalar], constant_keys: Mapping[str, DataType | MainValueType], - ): - if node.model is None: - return - + ) -> Connection | None: # Model id is a unique key for unique operation model_id: list[str] = [] for key, conn in node.connections.items(): @@ -373,7 +368,7 @@ def _is_duplicate( elif self.is_tensor_type(ref_value) and self.is_tensor_type(value): is_equal = ( id(ref_value) == id(value) - or ref_value.shape == value.shape + or ref_value.shape == value.shape # type: ignore and (ref_value == value).all().item() # type: ignore ) else: @@ -396,8 +391,9 @@ def _is_duplicate( return self.unique_model_table[final_model_id] self.unique_model_table[final_model_id] = node.connections["output"] + return None - def _prune_node(self, node: Node, conn: Connection): + def _prune_node(self, node: Node, conn: Connection) -> None: self.collapse_model_keys(node.connections["output"].key, conn.key) # Update source and target keys of node connections @@ -428,15 +424,14 @@ def _prune_node(self, node: Node, conn: Connection): ) not in self.output_keys and key in self._all_target_keys: self._all_target_keys.remove(key) - if node.model is not None: - self.nodes.pop(node.model) + self.nodes.pop(node.model) self._update_connection_keys(conn) self._update_all_source_keys() self._update_all_target_keys() self._update_topological_order() - def _remove_node(self, node: Node): + def _remove_node(self, node: Node) -> None: connections = set(node.connections.values()) output_conn = node.connections[PrimitiveModel.output_key] @@ -448,12 +443,11 @@ def _remove_node(self, node: Node): self._update_connection_keys(conn) self._remove_conn(output_conn) - if node.model is not None: - self.nodes.pop(node.model) + self.nodes.pop(node.model) self._update_topological_order() - def _remove_conn(self, conn: Connection): + def _remove_conn(self, conn: Connection) -> None: self.connections.pop(conn.key, None) # Remove connection from other connections @@ -462,13 +456,13 @@ def _remove_conn(self, conn: Connection): if conn.key in conn_.target_keys: conn_.target_keys.remove(conn.key) - if conn.key in self._all_source_keys: # and conn.key not in self.alias_map: + if conn.key in self._all_source_keys: self._all_source_keys.remove(conn.key) - if conn.key in self._all_target_keys: # and conn.key not in self.alias_map: + if conn.key in self._all_target_keys: self._all_target_keys.remove(conn.key) - def remove_key(self, key: str): + def remove_key(self, key: str) -> None: if key in self.output_dict: self.output_dict.pop(key) @@ -479,7 +473,7 @@ def remove_key(self, key: str): def infer_ignore_step( self, key: str, keys: set[str], queue: set[str], from_source: bool - ): + ) -> None: forward_key_fn: Callable[[str, bool], list[str]] if from_source: forward_key_fn = self.get_target_keys @@ -501,5 +495,5 @@ def infer_ignore_step( keys.add(value) queue.add(value) - def get_models(self): + def get_models(self) -> Iterable[PrimitiveModel]: return self.nodes.keys() diff --git a/mithril/framework/physical/model.py b/mithril/framework/physical/model.py index 9b191356..b3eca48a 100644 --- a/mithril/framework/physical/model.py +++ b/mithril/framework/physical/model.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import math import random import warnings @@ -27,6 +29,7 @@ TBD, Connection, ConnectionData, + ConnectionType, DataEvalType, EvaluateAllType, EvaluateGradientsType, @@ -37,12 +40,12 @@ NotAvailable, ParamsEvalType, Scalar, + ShapeResultType, Table, Tensor, UniadicRecord, Updates, Variadic, - _ShapesType, create_shape_map, get_shapes, get_summary, @@ -96,7 +99,7 @@ def __init__( # TODO: Remove wrapping with Model in the future. model = deepcopy(model) extend_info = model() - model_keys = {} + model_keys: dict[str, ConnectionType] = {} for key in model.external_keys: value = extend_info.connections.get(key, NOT_GIVEN) # NOTE: Do not set default value if it is given in constant_keys. @@ -139,9 +142,10 @@ def __init__( ].name key_origin = model.canonical_output.metadata.key_origin if key_origin != current_name: - while key_origin in flat_model.assigned_edges: + while key_origin in flat_model.assigned_names: key_origin = f"_{key_origin}" + assert key_origin is not None self._output_keys.add(key_origin) flat_model.rename_key(current_name, key_origin) @@ -228,17 +232,19 @@ def __init__( if self.backend.backend_type == "numpy": cache_name = "_".join([mappings[output], p_model.cache_name]) mappings["cache"] = cache_name - cache_value: dict | None = None if self.inference else dict() + cache_value: DataEvalType[DataType] | None = ( + None if self.inference else dict() + ) # Create A object for caches in manualgrad backend. cache_scalar = Scalar(dict | None, cache_value) self.data_store.update_data({cache_name: cache_scalar}) self.flat_graph.add_value(p_model, mappings) - self.data_store.set_random_seed_keys(self.flat_graph._random_keys) + self.data_store.set_random_seed_keys(self.flat_graph.random_keys) for cached_key in list(self.data_store.cached_data.keys()): - self.data_store._infer_unused_keys(cached_key) + self.data_store.infer_unused_keys(cached_key) # First part of the pm with all the inferences. self._pre_compile( @@ -270,9 +276,9 @@ def __init__( def __call__( self, - params: dict[str, DataType] | None = None, - data: Mapping[str, DataType | MainValueType] | None = None, - ): + params: ParamsEvalType[DataType] | None = None, + data: DataEvalType[DataType] | None = None, + ) -> DataEvalType[DataType]: return self.evaluate(params=params, data=data) @property @@ -368,7 +374,7 @@ def get_shapes( var_keys: dict[Variadic, str] | None = None, symbolic: bool = False, verbose: bool = False, - ) -> _ShapesType: + ) -> ShapeResultType: if model is not None: # Find corresponding data from self.data_store_data_memo. data_dict = { @@ -390,22 +396,24 @@ def get_shapes( ) @property - def data(self): - return self.data_store._all_data + def data(self) -> dict[str, Tensor | Scalar]: + return self.data_store.all_data @property - def shapes(self) -> _ShapesType: + def shapes(self) -> ShapeResultType: return self.get_shapes() @property - def output_keys(self): + def output_keys(self) -> list[str]: return sorted(self._output_keys) @property - def input_keys(self): + def input_keys(self) -> set[str]: return self._input_keys - def _infer_differentiability(self, model: PrimitiveModel, dag: dict[str, str]): + def _infer_differentiability( + self, model: PrimitiveModel, dag: dict[str, str] + ) -> None: # Infer output differentiability only for the models # that have a Tensor type output. if isinstance(model.output.metadata.data, Tensor): @@ -500,7 +508,7 @@ def _pre_compile( constant_keys: dict[str, DataType | MainValueType], data_keys: set[str], shapes: PhysicalShapeType, - ): + ) -> None: self.ignore_grad_keys: set[str] = set() # Set given shapes. @@ -533,16 +541,16 @@ def _pre_compile( logical_id = reverse_data_memo[pruned_data] self.data_store.data_memo[logical_id] = remained_data - updates |= remained_data.match(pruned_data) + updates |= remained_data.match(pruned_data) # type: ignore self.data[key] = remained_data - for value in self.data_store._intermediate_non_differentiables.inverse: + for value in self.data_store.intermediate_non_differentiables.inverse: # there can exist some inferred intermediate scalar keys in logical model. # find those keys and add to cached datas if isinstance(value, Scalar) and value.value is not TBD: updates.add(value) - self.data_store._update_cached_data(updates) + self.data_store.update_cached_data(updates) self.data_store.constraint_solver(updates) @@ -661,9 +669,9 @@ def infer_ignore( def _calculate_parameters( self, - name_mappings: dict[Model, str], + name_mappings: dict[BaseModel, str], data_to_key_map: dict[Tensor | Scalar, list[str]] | None = None, - ): + ) -> tuple[dict[str, tuple[dict[str, str], dict[str, str]]], str]: total_params: int = 0 seen_data: set[Tensor] = set() exact_param_status: bool = True @@ -735,7 +743,7 @@ def _print_model_info( total_params: str, data_to_key_map: dict[Tensor | Scalar, list[str]], model: BaseModel | None = None, - ): + ) -> None: # Find constant inputs of the model. pm_constant_input_keys = ( self._input_keys - self.data_store.unused_keys @@ -803,7 +811,7 @@ def summary( alternative_shapes: bool = False, print_info: bool = True, name: str | None = None, - ): + ) -> None: uni_keys: dict[UniadicRecord, str] = dict() var_keys: dict[Variadic, str] = dict() if model is None and depth != 0: @@ -823,13 +831,9 @@ def summary( type_info = None # Extract all summary information - dag: list[PrimitiveModel] | dict[BaseModel, dict[str, ConnectionData]] + dag: list[BaseModel] | dict[BaseModel, dict[str, ConnectionData]] if model is not None: - if isinstance(model, PrimitiveModel): - dag = [model] - elif isinstance(model, Model): - dag = model.dag - + dag = model.dag if isinstance(model, Model) else [model] name_mappings = define_unique_names(dag) conn_info = model.extract_connection_info( name_mappings, data_to_key_map, self.data_store.data_memo @@ -844,9 +848,9 @@ def summary( all_models.remove(unused_model.node.model) name_mappings = define_unique_names(all_models) - conn_info = self.extract_connection_info(name_mappings) + conn_info = self.extract_connection_info(name_mappings) # type: ignore - model_shapes: dict[str, _ShapesType] = { + model_shapes: dict[str, ShapeResultType] = { sub_model_name: self.get_shapes( sub_model, uni_keys, var_keys, symbolic, alternative_shapes ) @@ -855,7 +859,8 @@ def summary( # calculate all key parameters and total parameters param_info, total_parameters = self._calculate_parameters( - name_mappings, data_to_key_map + name_mappings, + data_to_key_map, ) if print_info: @@ -901,7 +906,7 @@ def summary( def extract_connection_info( self, name_mappings: dict[PrimitiveModel, str] | None = None - ): + ) -> dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]]: if name_mappings is None: name_mappings = define_unique_names(self.flat_graph.get_models()) conn_info: dict[str, tuple[dict[str, list[str]], dict[str, list[str]]]] = {} @@ -964,7 +969,7 @@ def extract_connection_info( def set_random_seed_values(self, **seed_mapping: int) -> None: self.data_store.set_random_seed_values(**seed_mapping) - def _step_random_seed_values(self): + def _step_random_seed_values(self) -> None: for key, value in self.data_store._random_seeds.items(): random.seed(value) new_seed = random.randint(0, 2**14) @@ -1007,7 +1012,7 @@ def _replace_with_primitive( ) primitive.parent = model.parent - p_key_mappings = {} + p_key_mappings: dict[str, str] = {} # for key in model._input_keys | model.output_keys: for key in model.external_keys: if key[0] != "$": @@ -1085,17 +1090,17 @@ class Name: name: str origin: str - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Name): return self.name == other.name if isinstance(other, str): return self.name == other return False - def startswith(self, prefix: str): + def startswith(self, prefix: str) -> bool: return self.name.startswith(prefix) @@ -1151,7 +1156,7 @@ def external_keys(self) -> set[str]: """ return set(self.external_mapping.values()) - def rename_key(self, source_name: str, target_name: str): + def rename_key(self, source_name: str, target_name: str) -> None: """ Rename a key from source_name to target_name. @@ -1168,7 +1173,7 @@ def rename_key(self, source_name: str, target_name: str): self._update_defined_names(source_name, target_name) - def _update_defined_names(self, old_key: str, new_key: str): + def _update_defined_names(self, old_key: str, new_key: str) -> None: old_name = self.assigned_names[old_key] if old_name.origin in self.key_origins: if self.key_origins[old_name.origin] == 0: @@ -1185,7 +1190,7 @@ def _update_defined_names(self, old_key: str, new_key: str): for key, value in self._external_mapping.items() } - def _name_externals(self): + def _name_externals(self) -> None: external_keys = list(self.model.conns.input_keys) + list( self.model.conns.output_keys ) @@ -1263,7 +1268,7 @@ def generate_keys( model: BaseModel, mappings: dict[str, str] | None = None, parent_name: str = "", - ): + ) -> None: """ Generate keys for the model. @@ -1289,7 +1294,7 @@ def generate_keys( def _process_primitive_model( self, model: PrimitiveModel, mappings: dict[str, str], parent_name: str - ): + ) -> None: """ Process a primitive model. @@ -1323,7 +1328,9 @@ def _process_primitive_model( self.used_edges.add(output_edge) self._check_for_queue(output_edge) - def _process_model(self, model: Model, mappings: dict[str, str], parent_name: str): + def _process_model( + self, model: Model, mappings: dict[str, str], parent_name: str + ) -> None: submodel_names = model.get_unique_submodel_names() for m, value in model.dag.items(): @@ -1355,7 +1362,7 @@ def _process_model(self, model: Model, mappings: dict[str, str], parent_name: st self.generate_keys(m, name_mapping, parent_name=name) - def _check_for_queue(self, hyperedge: IOHyperEdge): + def _check_for_queue(self, hyperedge: IOHyperEdge) -> None: if hyperedge in self.queued_models: for m, mappings, parent_name in self.queued_models[hyperedge]: if self._is_primitive_ready(m): @@ -1363,7 +1370,7 @@ def _check_for_queue(self, hyperedge: IOHyperEdge): m, mappings=mappings, parent_name=parent_name ) - def _is_primitive_ready(self, model: PrimitiveModel): + def _is_primitive_ready(self, model: PrimitiveModel) -> bool: """ Check if a primitive model is ready to be processed. @@ -1381,7 +1388,7 @@ def _is_primitive_ready(self, model: PrimitiveModel): def _add_primitive_to_queue( self, model: PrimitiveModel, mappings: dict[str, str], parent_name: str - ): + ) -> None: """ Add a primitive model to the queue. @@ -1432,7 +1439,7 @@ def _create_name(self, name: str, key_origin: str) -> Name: self.assigned_names[name] = new_name return new_name - def _rebase_names(self): + def _rebase_names(self) -> None: """ Rebase the names to remove unnecessary suffixes. """ @@ -1444,10 +1451,10 @@ def _rebase_names(self): self.assigned_names[name].name = base_name self.assigned_names[base_name] = self.assigned_names.pop(name) - def __iter__(self): + def __iter__(self) -> FlatModel: self._iter = iter(self.mappings.items()) return self - def __next__(self): + def __next__(self) -> tuple[PrimitiveModel, dict[str, str]]: model, mapping = next(self._iter) return model, {key: name.name for key, name in mapping.items()} diff --git a/mithril/framework/utils.py b/mithril/framework/utils.py index 5d3055d5..6a259eb5 100644 --- a/mithril/framework/utils.py +++ b/mithril/framework/utils.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterable from functools import reduce from itertools import product from types import FunctionType, GenericAlias, UnionType -from typing import Any +from typing import TYPE_CHECKING, Any, TypeVar + +if TYPE_CHECKING: + from .logical.base import BaseModel class NestedListType: @@ -33,11 +37,14 @@ def __init__(self, base_type: type | UnionType): self.base_type = base_type -def define_unique_names(models): +T = TypeVar("T", bound="BaseModel") + + +def define_unique_names(models: Iterable[T]) -> dict[T, str]: # TODO: Move this to Physical model (currently it is only used there) # TODO: Also add short-naming logic to this function - model_name_dict = {} - single_model_dict = {} + model_name_dict: dict[T, str] = {} + single_model_dict: dict[str, T] = {} model_count_dict: dict[str, int] = {} for model in models: diff --git a/mithril/models/models.py b/mithril/models/models.py index c7a60fd3..10e8b428 100644 --- a/mithril/models/models.py +++ b/mithril/models/models.py @@ -1271,7 +1271,7 @@ def __init__( self._freeze() - def __call__( # type: ignore[override] + def __call__( self, input: ConnectionType = NOT_GIVEN, output: ConnectionType = NOT_GIVEN, @@ -1751,7 +1751,7 @@ def __init__( super().__init__(name=name) # self.set_values(**kwargs) - def __call__(self, **kwargs) -> ExtendInfo: # type: ignore[override] + def __call__(self, **kwargs: ConnectionType) -> ExtendInfo: raise NotImplementedError("__call__ method not implemented!") @@ -1832,7 +1832,7 @@ def __init__( prev_cell = current_cell self._freeze() - def __call__( # type: ignore[override] + def __call__( self, input: ConnectionType = NOT_GIVEN, **model_keys: ConnectionType ) -> ExtendInfo: return super(RNN, self).__call__(input=input, **model_keys) @@ -1967,7 +1967,7 @@ def __init__( self._freeze() - def __call__( # type: ignore[override] + def __call__( self, hidden_concat: ConnectionType = NOT_GIVEN, **model_keys: ConnectionType ) -> ExtendInfo: return super(RNN, self).__call__(hidden_concat=hidden_concat, **model_keys) diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index 83c1a0b9..faf3dfc9 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -1880,7 +1880,7 @@ def __init__( key=BaseKey(type=int, value=key), ) - self._random_keys.add("key") + self.random_keys.add("key") self.set_constraint(randn_constraints, keys=["output", "shape"]) def __call__( # type: ignore[override] diff --git a/mithril/models/train_model.py b/mithril/models/train_model.py index deab9392..ca903f65 100644 --- a/mithril/models/train_model.py +++ b/mithril/models/train_model.py @@ -17,23 +17,23 @@ from copy import deepcopy from typing import Any, Self -from ..framework import ( +from ..framework import BaseModel, ExtendInfo, Model +from ..framework.common import ( NOT_GIVEN, - BaseModel, + TBD, Connection, ConnectionData, ConnectionType, - ExtendInfo, IOHyperEdge, IOKey, KeyType, - Model, + NotAvailable, + Table, UniadicRecord, Variadic, get_shapes, get_summary_shapes, ) -from ..framework.common import TBD, NotAvailable, Table from ..framework.logical import ( Buffer, Divide, @@ -531,7 +531,7 @@ def summary( if isinstance(self._model, Model): summary_kwargs["depth"] = depth - self._model.summary(**summary_kwargs) # type: ignore + self._model.summary(**summary_kwargs) name_mappings = self.get_unique_submodel_names() conn_info = self.extract_connection_info(name_mappings) diff --git a/mithril/utils/dict_conversions.py b/mithril/utils/dict_conversions.py index 6ecf608c..9f98eb10 100644 --- a/mithril/utils/dict_conversions.py +++ b/mithril/utils/dict_conversions.py @@ -350,7 +350,7 @@ def train_model_to_dict(context: TrainModel) -> dict: return context_dict -def dict_to_trainmodel(context_dict: dict): +def dict_to_trainmodel(context_dict: dict) -> BaseModel: model = dict_to_model(context_dict["model"]) assert isinstance(model, Model), "TrainModel requires a Model object!" diff --git a/mypy.ini b/mypy.ini index dfb40220..7fe2e636 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,4 +1,20 @@ [mypy] -check_untyped_defs = True -enable_incomplete_feature = NewGenericSyntax -ignore_missing_imports = True +strict = False +warn_return_any = False +ignore_missing_imports = True + + + +[mypy-mithril.framework.physical.*] +warn_unused_configs = True +disallow_any_generics = True +disallow_subclassing_any = True +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True +warn_unused_ignores = True +warn_return_any = False +no_implicit_reexport = True +ignore_missing_imports = True \ No newline at end of file diff --git a/releases/generate_changelog.py b/releases/generate_changelog.py index b9becbca..3a51c92a 100644 --- a/releases/generate_changelog.py +++ b/releases/generate_changelog.py @@ -34,7 +34,7 @@ import sys from contextlib import redirect_stdout -from github import Github # type: ignore +from github import Github def print_pulls(repo_name, title, pulls): diff --git a/tests/scripts/test_constraints.py b/tests/scripts/test_constraints.py index 4ec925d1..03e3e0f5 100644 --- a/tests/scripts/test_constraints.py +++ b/tests/scripts/test_constraints.py @@ -28,6 +28,7 @@ PossibleValues, Scalar, ShapeRepr, + ShapeResultType, ShapeTemplateType, Tensor, ToBeDetermined, @@ -36,7 +37,6 @@ Updates, UpdateType, Variadic, - _ShapesType, create_shape_repr, ) from mithril.framework.constraints import ( @@ -80,14 +80,16 @@ def is_type_checker( - ref_results: dict[str, type | NestedListType] | _ShapesType, constraint_fn: Callable + ref_results: dict[str, type | NestedListType] | ShapeResultType, + constraint_fn: Callable, ) -> TypeGuard[dict[str, type | NestedListType]]: return constraint_fn in type_constraints def is_shape_checker( - ref_results: dict[str, type | NestedListType] | _ShapesType, constraint_fn: Callable -) -> TypeGuard[_ShapesType]: + ref_results: dict[str, type | NestedListType] | ShapeResultType, + constraint_fn: Callable, +) -> TypeGuard[ShapeResultType]: return constraint_fn not in type_constraints @@ -197,7 +199,7 @@ def extract_variadic_possibles( def assert_shape_results( data: dict[str, Tensor | Scalar], - ref_results: _ShapesType, + ref_results: ShapeResultType, ref_assignments: AssignmentType, updated_symbols: Updates, expected_updates: set[str], @@ -267,7 +269,7 @@ def assert_value_results( def make_assertions( constraint_fn: Callable, data: dict[str, Tensor | Scalar], - ref_results: dict[str, type | NestedListType] | _ShapesType, + ref_results: dict[str, type | NestedListType] | ShapeResultType, ref_assignments: AssignmentType, updated_symbols: Updates, expected_updates: set[str], diff --git a/tests/scripts/test_data_store.py b/tests/scripts/test_data_store.py index 51f4512a..6be5e394 100644 --- a/tests/scripts/test_data_store.py +++ b/tests/scripts/test_data_store.py @@ -60,7 +60,7 @@ def test_data_store_1(): assert pm.data_store.data_values.keys() == {"input"} assert (pm.data_store.data_values[key].value == value).all() # type: ignore [union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -93,7 +93,7 @@ def test_data_store_1_numpy(): } assert (pm.data_store.data_values[key].value == value).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -110,7 +110,7 @@ def test_data_store_3(): assert pm.data_store.data_values.keys() == {"output_1"} assert (pm.data_store.data_values["output_1"] == backend.array(6.0)).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "input", "weight", @@ -202,7 +202,7 @@ def test_data_store_7(): assert pm.data_store.data_values.keys() == {"input"} assert (res["output"] == value).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -218,7 +218,7 @@ def test_data_store_8(): assert pm.data_store.data_values.keys() == {"output1"} assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"input"} @@ -235,7 +235,7 @@ def test_data_store_9(): assert pm.data_store.data_values.keys() == {"output1"} assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"input"} @@ -252,7 +252,7 @@ def test_data_store_10(): assert pm.data_store.data_values.keys() == {"input", "output2"} assert (pm.data_store.data_values["output2"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == set() @@ -269,7 +269,7 @@ def test_data_store_11(): assert (pm.data_store.data_values["output1"] == backend.sigmoid(value)).all() # type: ignore[union-attr] assert (pm.data_store.data_values["output3"] == backend.sigmoid(value) + 2).all() # type: ignore[union-attr] assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "right", "input", @@ -294,7 +294,7 @@ def test_data_store_13(): assert pm.data_store.data_values.keys() == {"out"} assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == {"left", "right"} infered_value = pm.data_store.data_values["out"] @@ -328,7 +328,7 @@ def test_data_store_14(): ) assert pm.data_store.data_values.keys() == {"input1", "out2"} assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "output_0", @@ -383,7 +383,7 @@ def test_data_store_15(): ) assert pm.data_store.data_values.keys() == {"input1", "out2"} assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table == dict() + assert pm.data_store.intermediate_non_differentiables._table == dict() assert pm.data_store.unused_keys == { "output_6", @@ -437,7 +437,7 @@ def test_data_store_16(): "output_cache", } assert pm.data_store.runtime_static_keys == {"input"} - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -464,7 +464,7 @@ def test_data_store_17(): assert pm.data_store.data_values.keys() == {"output_0_cache", "output_cache"} assert pm.data_store.runtime_static_keys == {"right"} - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -494,7 +494,7 @@ def test_data_store_18(): assert pm.data_store.data_values.keys() == set() assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -524,7 +524,7 @@ def test_data_store_19(): assert pm.data_store.data_values.keys() == set() assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == set() @@ -554,5 +554,5 @@ def test_data_store_20(): assert pm.data_store.data_values.keys() == {"tensor_out"} assert pm.data_store.runtime_static_keys == set() - assert pm.data_store._intermediate_non_differentiables._table.keys() == set() + assert pm.data_store.intermediate_non_differentiables._table.keys() == set() assert pm.data_store.unused_keys == {"left", "output_1"} diff --git a/tests/scripts/test_flatmodel.py b/tests/scripts/test_flatmodel.py index a8214113..4e2d97b0 100644 --- a/tests/scripts/test_flatmodel.py +++ b/tests/scripts/test_flatmodel.py @@ -397,7 +397,7 @@ def test_integration_collision_from_different_levels(): pm_short = ml.compile(model, backend) pm_long = ml.compile(model, backend, use_short_namings=False) - input_short = {"d": backend.array([1, 2, 3]), "e_1": backend.array([4, 5, 6])} + input_short = {"d": backend.array([1, 2, 3]), "e": backend.array([4, 5, 6])} input_long = { "middle_d": backend.array([1, 2, 3]), "middle_e": backend.array([4, 5, 6]), @@ -406,7 +406,7 @@ def test_integration_collision_from_different_levels(): res_short = pm_short.evaluate(input_short) res_long = pm_long.evaluate(input_long) - expected_res = {"e": backend.array([5, 7, 9], dtype=ml.int64)} + expected_res = {"_e": backend.array([5, 7, 9], dtype=ml.int64)} - np.testing.assert_allclose(expected_res["e"], res_short["e"]) # type: ignore - np.testing.assert_allclose(expected_res["e"], res_long["e"]) # type: ignore + np.testing.assert_allclose(expected_res["_e"], res_short["_e"]) # type: ignore + np.testing.assert_allclose(expected_res["_e"], res_long["e"]) # type: ignore diff --git a/tests/scripts/test_pm.py b/tests/scripts/test_pm.py index c2df8900..f5ad8221 100644 --- a/tests/scripts/test_pm.py +++ b/tests/scripts/test_pm.py @@ -19,7 +19,7 @@ from mithril.models import Model, Randn -def test_random_keys_not_provided(): +def testrandom_keys_not_provided(): example_model = Model() example_model += Randn()(shape=(3, 4, 5, 6), output=ml.IOKey("out1")) example_model += Randn()(shape=(3, 4, 5, 1), output=ml.IOKey("out2")) @@ -38,7 +38,7 @@ def test_random_keys_not_provided(): np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, res1, res2) -def test_random_keys_some_of_provided(): +def testrandom_keys_some_of_provided(): example_model = Model() # Static inference will infer this function example_model += Randn(key=42)(shape=(3, 4, 5, 6), output=ml.IOKey("out1")) @@ -58,7 +58,7 @@ def test_random_keys_some_of_provided(): np.testing.assert_array_equal(res1, res2) -def test_set_random_keys(): +def test_setrandom_keys(): example_model = Model() # Static inference will infer this function example_model += Randn(key=42)(shape=(3, 4, 5, 6), output=ml.IOKey("out1")) diff --git a/tests/utils.py b/tests/utils.py index 24b65b0d..2dfaf897 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -93,8 +93,8 @@ def check_physical_models( assert backend.all(value.value == pm_2.data[key].value) # type: ignore assert pm_1.data_store.cached_data.keys() == pm_2.data_store.cached_data.keys() assert ( - pm_1.data_store._intermediate_non_differentiables._table.keys() - == pm_2.data_store._intermediate_non_differentiables._table.keys() + pm_1.data_store.intermediate_non_differentiables._table.keys() + == pm_2.data_store.intermediate_non_differentiables._table.keys() ) assert ( pm_1.data_store.runtime_static_keys == pm_2.data_store.runtime_static_keys