From 090136fafde9f0c55813686949e93197362c1be9 Mon Sep 17 00:00:00 2001
From: aturker-synnada
Date: Wed, 25 Dec 2024 15:52:44 +0300
Subject: [PATCH] remove creation fn wrapper

---
 mithril/backends/backend.py                   |  30 +-
 .../with_autograd/jax_backend/backend.py      | 262 ++++++++----------
 .../with_autograd/jax_backend/utils.py        |  62 +----
 .../with_autograd/mlx_backend/backend.py      | 108 +++-----
 .../with_autograd/mlx_backend/utils.py        |  50 +---
 .../with_autograd/torch_backend/backend.py    | 227 +++++++--------
 .../with_autograd/torch_backend/utils.py      |  72 +----
 .../with_manualgrad/numpy_backend/backend.py  | 136 +++------
 .../with_manualgrad/numpy_backend/utils.py    |  51 +---
 9 files changed, 334 insertions(+), 664 deletions(-)

diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py
index d2562b72..3245bdea 100644
--- a/mithril/backends/backend.py
+++ b/mithril/backends/backend.py
@@ -146,8 +146,30 @@ def arange(
         dtype: core.Dtype | None = None,
     ) -> DataType: ...
 
-    def arange(self, *args: int | float, **kwargs: Any) -> DataType:
-        raise NotImplementedError("arange is not implemented!")
+    def arange(self, *args: int | float, **kwargs) -> DataType:
+        """Generate an array of evenly spaced values within a specified range."""
+        if len(args) == 0:
+            raise RuntimeError(
+                "arange() missing 1 required positional argument: 'stop'"
+            )
+        elif len(args) == 1:
+            return self._arange(0, args[0], 1, **kwargs)  # type: ignore
+        elif len(args) == 2:
+            if args[0] >= args[1]:
+                return self.array([])
+
+            return self._arange(  # type: ignore
+                args[0], args[1], 1, **kwargs
+            )
+        elif len(args) == 3:
+            return self._arange(  # type: ignore
+                args[0], args[1], args[2], **kwargs
+            )
+        else:
+            raise RuntimeError(
+                "arange() accepts 1 to 3 positional arguments,"
+                f" but {len(args)} were provided"
+            )
 
     def flatten(
         self, input: DataType, start_dim: int = 0, end_dim: int = -1
@@ -459,7 +481,7 @@ def linspace(
         self,
         start: int | float | bool | DataType,
         stop: int | float | bool | DataType,
-        steps: int | DataType,
+        steps: int,
         dtype: core.Dtype | None = None,
     ) -> DataType:
         """
@@ -1349,7 +1371,7 @@ def linspace(
         self,
         start: int | float | bool | DataType,
         stop: int | float | bool | DataType,
-        steps: int | DataType,
+        steps: int,
         dtype: core.Dtype | None = None,
         device_mesh: tuple[int, ...] | None = None,
     ) -> DataType:
diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py
index cf173964..8573674a 100644
--- a/mithril/backends/with_autograd/jax_backend/backend.py
+++ b/mithril/backends/with_autograd/jax_backend/backend.py
@@ -15,7 +15,6 @@
 import math
 import os
 from collections.abc import Callable, Sequence
-from functools import partial
 from typing import Any, overload
 
 import jax
@@ -141,67 +140,6 @@ def block_until_ready(self, data: jax.Array) -> jax.Array | None:
         """
         return data.block_until_ready()
 
-    def _creation_fn_wrapper(
-        self, fn: Callable[..., jax.Array]
-    ) -> Callable[..., jax.Array]:
-        """
-        Wrapper for array creation functions.
-
-        Parameters
-        ----------
-        fn: Callable
-            The original array creation function.
-
-        Returns
-        -------
-        Callable
-            A wrapped function that creates arrays with specified dtype and device.
-
-        Notes
-        -----
-        Ensures that arrays are created with the correct dtype and device.
- """ - - array_conversion_fn = partial( - utils.creation_fn_wrapper, - fn=fn, - device=self._device, - precision=self.precision, - ) - array_conversion_fn = partial(self._parallelize, fn=array_conversion_fn) - - return array_conversion_fn - - def _parallelize( - self, - *args: Any, - fn: Callable[..., jax.Array], - device_mesh: tuple[int, ...], - **kwargs: Any, - ) -> jax.Array: - """ - Parallelizes the function's return tensor across devices. - - Parameters - ---------- - fn : Callable - The function whose return tensor will be parallelized. - - device_mesh : tuple[int, ...], optional - A tuple specifying the device mesh for parallelization. - If not provided, the default device mesh is used. - - Returns - ------- - Callable - Return tensor parallelized across the specified device mesh. - """ - - tensor: jax.Array = fn(*args, **kwargs) - if self._parallel_manager is None: - return tensor - return self._parallel_manager.parallelize(tensor, device_mesh) - def _register_callable( self, fn: Callable[..., Any], fn_name: str, jit: bool = False ): @@ -232,9 +170,11 @@ def array( ) -> jax.Array: _dtype = utils.determine_dtype(input, dtype, self.precision) - array = jax.numpy.array( - input, dtype=utils.dtype_map[_dtype], device=self.device - ) + with jax.default_device(self.device): + array = jax.numpy.array( + input, dtype=utils.dtype_map[_dtype], device=self.device + ) + if self._parallel_manager is not None: array = self._parallel_manager.parallelize(array, device_mesh) @@ -246,14 +186,16 @@ def zeros( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - result = self._creation_fn_wrapper(jax.numpy.zeros)( - _shape, dtype=_dtype, device_mesh=device_mesh - ) - return result + + with jax.default_device(self.device): + array = jax.numpy.zeros(_shape, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def ones( self, @@ -261,14 +203,16 @@ def ones( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - result = self._creation_fn_wrapper(jax.numpy.ones)( - _shape, dtype=_dtype, device_mesh=device_mesh - ) - return result + + with jax.default_device(self.device): + array = jax.numpy.ones(_shape, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def ones_like( self, @@ -277,13 +221,15 @@ def ones_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] 
| None = None,
     ) -> jax.Array:
-        _dtype: jax.numpy.dtype[Any] | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
-        result = self._creation_fn_wrapper(jax.numpy.ones_like)(
-            input, dtype=_dtype, device_mesh=device_mesh
-        )
-        return result
+        _dtype = self._process_dtype(dtype) if dtype is not None else None
+
+        with jax.default_device(self.device):
+            array = jax.numpy.ones_like(input, dtype=_dtype)
+
+        if self._parallel_manager is not None:
+            array = self._parallel_manager.parallelize(array, device_mesh)
+
+        return array
 
     def zeros_like(
         self,
@@ -292,13 +238,15 @@ def zeros_like(
         dtype: Dtype | None = None,
         device_mesh: tuple[int, ...] | None = None,
     ) -> jax.Array:
-        _dtype: jax.numpy.dtype[Any] | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
-        result = self._creation_fn_wrapper(jax.numpy.zeros_like)(
-            input, dtype=_dtype, device_mesh=device_mesh
-        )
-        return result
+        _dtype = self._process_dtype(dtype) if dtype is not None else None
+
+        with jax.default_device(self.device):
+            array = jax.numpy.zeros_like(input, dtype=_dtype)
+
+        if self._parallel_manager is not None:
+            array = self._parallel_manager.parallelize(array, device_mesh)
+
+        return array
 
     def randn(
         self,
@@ -309,14 +257,17 @@ def randn(
     ) -> jax.Array:
         if prng_key is None:
             prng_key = self.prng_key
-        _dtype: jax.numpy.dtype[Any] | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
+
+        _dtype = self._process_dtype(dtype)
         _shape = process_shape(shape)
-        result = self._creation_fn_wrapper(jax.random.normal)(
-            prng_key, _shape, dtype=_dtype, device_mesh=device_mesh
-        )
-        return result
+
+        with jax.default_device(self.device):
+            array = jax.random.normal(prng_key, _shape, dtype=_dtype)
+
+        if self._parallel_manager is not None:
+            array = self._parallel_manager.parallelize(array, device_mesh)
+
+        return array
 
     def rand(
         self,
@@ -327,14 +278,17 @@ def rand(
     ) -> jax.Array:
         if prng_key is None:
             prng_key = self.prng_key
-        _dtype: jax.numpy.dtype[Any] | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
+
+        _dtype = self._process_dtype(dtype)
         _shape = process_shape(shape)
-        result = self._creation_fn_wrapper(jax.random.uniform)(
-            prng_key, _shape, dtype=_dtype, device_mesh=device_mesh
-        )
-        return result
+
+        with jax.default_device(self.device):
+            array = jax.random.uniform(prng_key, _shape, dtype=_dtype)
+
+        if self._parallel_manager is not None:
+            array = self._parallel_manager.parallelize(array, device_mesh)
+
+        return array
 
     def randint(
         self,
@@ -347,19 +301,17 @@ def randint(
     ) -> jax.Array:
         if prng_key is None:
             prng_key = self.prng_key
-        _dtype: jax.numpy.dtype[Any] | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
+
+        _dtype = self._process_dtype(dtype, int)
         _shape = process_shape(shape)
-        result = self._creation_fn_wrapper(jax.random.randint)(
-            prng_key,
-            _shape,
-            low,
-            high,
-            dtype=_dtype,
-            device_mesh=device_mesh,
-        )
-        return result
+
+        with jax.default_device(self.device):
+            array = jax.random.randint(prng_key, _shape, low, high, dtype=_dtype)
+
+        if self._parallel_manager is not None:
+            array = self._parallel_manager.parallelize(array, device_mesh)
+
+        return array
 
     def rand_uniform(
         self,
@@ -372,47 +324,56 @@ def rand_uniform(
     ) -> jax.Array:
         if prng_key is None:
             prng_key = self.prng_key
-        _dtype: jax.numpy.dtype[Any] | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
+
+        _dtype = self._process_dtype(dtype)
         _shape = 
process_shape(shape) - return self._creation_fn_wrapper(jax.random.uniform)( - prng_key, - _shape, - dtype=_dtype, - minval=low, - maxval=high, - device_mesh=device_mesh, - ) + + with jax.default_device(self.device): + array = jax.random.uniform(prng_key, _shape, _dtype, low, high) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def _arange( self, - *args: int | float, + start: int | float, + stop: int | float, + step: int | float, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, **kwargs: Any, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(jax.numpy.arange)( - *args, dtype=_dtype, device_mesh=device_mesh + default_type = ( + float if any(isinstance(x, float) for x in (start, stop, step)) else int ) + _dtype = self._process_dtype(dtype, default_type) + + with jax.default_device(self.device): + array = jax.numpy.arange(start, stop, step, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def linspace( self, start: int | float | bool | jax.numpy.ndarray, stop: int | float | bool | jax.numpy.ndarray, - steps: int | jax.numpy.ndarray, + steps: int, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> jax.Array: - _dtype: jax.numpy.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(jax.numpy.linspace)( - start, stop, steps, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = self._process_dtype(dtype) + with jax.default_device(self.device): + array = jax.numpy.linspace(start, stop, steps, dtype=_dtype) + + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize(array, device_mesh) + + return array def flatten( self, input: jax.Array, start_dim: int = 0, end_dim: int = -1 @@ -697,3 +658,18 @@ def jacfwd( self, fn: Callable[..., dict[str, jax.Array]] ) -> Callable[..., dict[str, jax.Array]]: return jax.jacfwd(fn) + + def _process_dtype( + self, + dtype: Dtype | None = None, + default_type: type[float] | type[int] | type[bool] = float, + ) -> jax.numpy.dtype[Any]: + if isinstance(dtype, Dtype): + return utils.dtype_map[dtype.name] + elif dtype is None: + return utils.dtype_map[default_type.__name__ + str(self.precision)] + else: + raise ValueError(f"Invalid dtype {dtype}") + + def _get_defualt_type(self): + return getattr(self, f"float{self.precision}") diff --git a/mithril/backends/with_autograd/jax_backend/utils.py b/mithril/backends/with_autograd/jax_backend/utils.py index a60b02d9..9c5139b6 100644 --- a/mithril/backends/with_autograd/jax_backend/utils.py +++ b/mithril/backends/with_autograd/jax_backend/utils.py @@ -333,66 +333,6 @@ def _parse_device_string(device: str): return backend, device_idx -def handle_dtype(dtype: str | core.Dtype | jnp.dtype[Any]) -> jnp.dtype[Any]: - if isinstance(dtype, core.Dtype): - return dtype_map[dtype.name] - elif isinstance(dtype, str) and dtype in dtype_map: - return dtype_map[dtype] - else: - try: - return jnp.dtype(dtype) - except TypeError as err: - raise TypeError(f"Provided data type '{dtype}' not understood") from err - - -def creation_fn_wrapper( - *args: Any, - fn: Callable[..., jax.Array], - dtype: core.Dtype | jnp.dtype[Any] | None = None, - device: str, - precision: int, - **kwargs: Any, -): - _device = 
get_device(device) - - if dtype is not None: - dtype = handle_dtype(dtype) - with jax.default_device(_device): - data = fn(*args, dtype=dtype, **kwargs) - else: - with jax.default_device(_device): - data = fn(*args, **kwargs) - data = handle_data_precision(data, precision) - return data - - -def conversion_fn_wrapper( - data: Any, - *args: Any, - fn: Callable[..., jax.Array], - device: str, - precision: int, - dtype: core.Dtype | jnp.dtype[Any] | None = None, - **kwargs: Any, -): - _device = get_device(device) - - if dtype is not None: - dtype = handle_dtype(dtype) - if isinstance(data, ArrayType): - if next(iter(data.devices())) != _device: - data = jax.device_put(data, _device) - if dtype is not None: - return data.astype(dtype) - return handle_data_precision(data, precision) - else: - with jax.default_device(_device): - _data = fn(data, *args, dtype=dtype, **kwargs) - if dtype is None: # User did not specify dtype explicitly - return handle_data_precision(_data, precision) - return _data - - def handle_data_precision(data: ArrayType, precision: int) -> ArrayType: _dtype = data.dtype # Do not make any changes to boolean types. @@ -515,7 +455,7 @@ def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str dtype_name = "".join( char for char in input.dtype.__str__() if not char.isdigit() ) - elif isinstance(input, np.ndarray) or isinstance(input, np.generic): + elif isinstance(input, (np.ndarray | np.generic)): dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit()) else: dtype_name = find_dominant_type(input).__name__ diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py index 5d4b55ae..ba6ce0ea 100644 --- a/mithril/backends/with_autograd/mlx_backend/backend.py +++ b/mithril/backends/with_autograd/mlx_backend/backend.py @@ -93,24 +93,6 @@ def to_device( def block_until_ready(self, data: mx.array): mx.eval(data) - def _creation_fn_wrapper( - self, fn: Callable[..., mx.array] - ) -> Callable[..., mx.array]: - return partial( - utils.creation_fn_wrapper, - fn=fn, - precision=self.precision, - ) - - def _conversion_fn_wrapper( - self, fn: Callable[..., mx.array] - ) -> Callable[..., mx.array]: - return partial( - utils.conversion_fn_wrapper, - fn=fn, - precision=self.precision, - ) - def _handle_dict_type_fun( self, *inputs: mx.array, @@ -194,32 +176,28 @@ def array(self, input: Any, *, dtype: Dtype | None = None) -> mx.array: def zeros( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.zeros)(shape=_shape, dtype=_dtype) + return mx.zeros(_shape, dtype=_dtype) def ones( self, *shape: int | tuple[int, ...] 
| list[int], dtype: Dtype | None = None ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.ones)(shape=_shape, dtype=_dtype) + return mx.ones(_shape, dtype=_dtype) def ones_like(self, input: mx.array, *, dtype: Dtype | None = None) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(mx.ones_like)(input, dtype=_dtype) + if dtype is not None: + raise ValueError("dtype argument is not supported for ones_like") + + return mx.ones_like(input) def zeros_like(self, input: mx.array, *, dtype: Dtype | None = None) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(mx.zeros_like)(input, dtype=_dtype) + if dtype is not None: + raise ValueError("dtype argument is not supported for ones_like") + + return mx.zeros_like(input) def randn( self, @@ -227,11 +205,9 @@ def randn( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.normal)(shape=_shape, dtype=_dtype) + return mx.random.normal(shape=_shape, dtype=_dtype) def rand( self, @@ -239,11 +215,9 @@ def rand( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.uniform)(shape=_shape, dtype=_dtype) + return mx.random.uniform(shape=_shape, dtype=_dtype) def randint( self, @@ -253,13 +227,9 @@ def randint( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype, int) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.randint)( - low=low, high=high, shape=_shape, dtype=_dtype - ) + return mx.random.randint(low, high, shape=_shape, dtype=_dtype) def rand_uniform( self, @@ -269,19 +239,23 @@ def rand_uniform( dtype: Dtype | None = None, prng_key: Any = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(mx.random.uniform)( - low=low, high=high, shape=_shape, dtype=_dtype + return mx.random.uniform(low, high, shape=_shape, dtype=_dtype) + + def _arange( + self, + start: int | float, + stop: int | float, + step: int | float, + dtype: Dtype | None = None, + ) -> mx.array: + default_type = ( + float if any(isinstance(x, float) for x in (start, stop, step)) else int ) + _dtype = self._process_dtype(dtype, default_type) - def arange(self, *args: float | int, dtype: Dtype | None = None) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(mx.arange)(*args, dtype=_dtype) + return mx.arange(start, stop, step, dtype=_dtype) def linspace( self, @@ -290,10 +264,8 @@ def linspace( steps: int | mx.array, dtype: Dtype | 
None = None, ) -> mx.array: - _dtype: mx.Dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(mx.linspace)(start, stop, steps, dtype=_dtype) + _dtype = self._process_dtype(dtype) + return mx.linspace(start, stop, steps, dtype=_dtype) def flatten( self, input: mx.array, start_dim: int = 0, end_dim: int = -1 @@ -671,3 +643,15 @@ def vmap( # type: ignore #mypy bug self, fn: Callable[[mx.array], mx.array] ) -> Callable[[mx.array], mx.array]: return mx.vmap(fn) + + def _process_dtype( + self, + dtype: Dtype | None = None, + default_type: type[float] | type[int] | type[bool] = float, + ) -> mx.Dtype: + if isinstance(dtype, Dtype): + return utils.dtype_map[dtype.name] + elif dtype is None: + return utils.dtype_map[default_type.__name__ + str(self.precision)] + else: + raise ValueError(f"Invalid dtype {dtype}") diff --git a/mithril/backends/with_autograd/mlx_backend/utils.py b/mithril/backends/with_autograd/mlx_backend/utils.py index 266d16a3..14d908d9 100644 --- a/mithril/backends/with_autograd/mlx_backend/utils.py +++ b/mithril/backends/with_autograd/mlx_backend/utils.py @@ -52,54 +52,6 @@ def get_device(device: str): return mx.Device(getattr(mx, device), 0) -def creation_fn_wrapper( - *args: Any, - fn: Callable[..., mx.array], - dtype: core.Dtype | mx.Dtype | None = None, - precision: int, - **kwargs: Any, -): - if dtype is not None: - dtype = handle_dtype(dtype) - data = fn(*args, dtype=dtype, **kwargs) - else: - data = fn(*args, **kwargs) - data = handle_data_precision(data, precision) - return data - - -def conversion_fn_wrapper( - data: Any, - *args: Any, - fn: Callable[..., mx.array], - precision: int, - dtype: mx.Dtype | None = None, - **kwargs: Any, -): - if dtype is not None: - dtype = handle_dtype(dtype) - if isinstance(data, ArrayType): - if dtype is not None: - return data.astype(dtype) - return handle_data_precision(data, precision) - else: - _data = fn(data, *args, dtype=dtype, **kwargs) - if dtype is None: # User did not specify dtype explicitly - return handle_data_precision(_data, precision) - return _data - - -def handle_dtype(dtype: Any) -> Any: - if isinstance(dtype, core.Dtype): - return dtype_map[dtype.name] - elif isinstance(dtype, str) and dtype in dtype_map: - return dtype_map[dtype] - elif isinstance(dtype, mx.Dtype): - return dtype - else: - raise TypeError(f"Provided data type '{dtype}' not understood") - - def handle_data_precision(data: mx.array, precision: int) -> mx.array: _dtype = data.dtype # Do not make any changes to boolean types. @@ -428,7 +380,7 @@ def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str dtype_name = "".join( char for char in input.dtype.__str__().split(".")[-1] if not char.isdigit() ) - elif isinstance(input, np.ndarray) or isinstance(input, np.generic): + elif isinstance(input, (np.ndarray | np.generic)): dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit()) else: dtype_name = find_dominant_type(input).__name__ diff --git a/mithril/backends/with_autograd/torch_backend/backend.py b/mithril/backends/with_autograd/torch_backend/backend.py index f99125aa..7fd130bb 100644 --- a/mithril/backends/with_autograd/torch_backend/backend.py +++ b/mithril/backends/with_autograd/torch_backend/backend.py @@ -13,7 +13,6 @@ # limitations under the License. 
from collections.abc import Callable, Sequence -from functools import partial from typing import Any, overload import torch @@ -24,7 +23,6 @@ from torch._functorch.eager_transforms import jacfwd as torch_jacfwd from torch._functorch.eager_transforms import jacrev as torch_jacrev from torch._functorch.eager_transforms import vjp as torch_vjp -from torch.distributed._tensor import DTensor from ....core import Dtype from ...backend import PadWidthType, ParallelBackend @@ -139,70 +137,6 @@ def empty_cache(self) -> None: pass # print(f"Warning: empty_cache is not implemented for {self.device_type}") - def _creation_fn_wrapper( - self, fn: Callable[..., torch.Tensor] - ) -> Callable[..., torch.Tensor]: - """ - Wrapper for PyTorch tensor creation functions. - - Parameters - ---------- - fn: Callable - The original tensor creation function. - - Returns - ------- - Callable - A wrapped function that creates tensors with specified dtype and device. - - Notes - ----- - This wrapper ensures that tensors are created with the correct dtype - and on the specified device. - """ - - array_creation_fn = partial( - utils.creation_fn_wrapper_inner, - fn=fn, - device=self._device, - precision=self.precision, - ) - array_creation_fn = partial(self._parallelize, fn=array_creation_fn) - - return array_creation_fn - - def _parallelize( - self, - *args: Any, - fn: Callable[..., torch.Tensor], - device_mesh: tuple[int] | None, - **kwargs: Any, - ) -> DTensor | torch.Tensor: - """ - Parallelizes the function's return tensor across devices. - - Parameters - ---------- - fn : Callable - The function whose return tensor will be parallelized. - device_mesh : tuple[int, ...], optional - A tuple specifying the device mesh for parallelization. - If not provided, the default device mesh is used. - - Returns - ------- - Callable - Returns tensor parallelized across the specified device mesh. - """ - tensor: torch.Tensor = fn(*args, **kwargs) - if self._parallel_manager is None: - # TODO: raise device_mesh should be None - return tensor - - return self._parallel_manager.parallelize( - tensor, self.base_device_mesh, device_mesh - ) - def _register_callable( self, fn: Callable[..., torch.Tensor], fn_name: str, jit: bool = False ): @@ -269,14 +203,13 @@ def array( ) -> torch.Tensor: _dtype = utils.determine_dtype(input, dtype, self.precision) - tensor = torch.tensor(input, dtype=utils.dtype_map[_dtype], device=self._device) - + array = torch.tensor(input, dtype=utils.dtype_map[_dtype], device=self._device) if self._parallel_manager is not None: - tensor = self._parallel_manager.parallelize( - tensor, self.base_device_mesh, device_mesh + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh ) - return tensor + return array def zeros( self, @@ -284,13 +217,16 @@ def zeros( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.zeros)( - _shape, dtype=_dtype, device_mesh=device_mesh - ) + + array = torch.zeros(_shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + + return array def ones( self, @@ -298,13 +234,15 @@ def ones( dtype: Dtype | None = None, device_mesh: tuple[int, ...] 
| None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.ones)( - _shape, dtype=_dtype, device_mesh=device_mesh - ) + + array = torch.ones(_shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def ones_like( self, @@ -313,12 +251,14 @@ def ones_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(torch.ones_like)( - input, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = self._process_dtype(dtype) if dtype is not None else None + + array = torch.ones_like(input, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def zeros_like( self, @@ -327,12 +267,14 @@ def zeros_like( dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(torch.zeros_like)( - input, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = self._process_dtype(dtype) if dtype is not None else None + + array = torch.zeros_like(input, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def randn( self, @@ -341,13 +283,16 @@ def randn( device_mesh: tuple[int, ...] | None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.randn)( - size=_shape, dtype=_dtype, device_mesh=device_mesh - ) + + # TODO: PRNG key is not used + array = torch.randn(_shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def rand( self, @@ -356,13 +301,15 @@ def rand( device_mesh: tuple[int, ...] | None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.rand)( - size=_shape, dtype=_dtype, device_mesh=device_mesh - ) + + array = torch.rand(_shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def randint( self, @@ -373,17 +320,15 @@ def randint( device_mesh: tuple[int, ...] 
| None = None, prng_key: Any = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype, int) _shape = process_shape(shape) - return self._creation_fn_wrapper(torch.randint)( - low, - high, - size=_shape, - dtype=_dtype, - device_mesh=device_mesh, - ) + + array = torch.randint(low, high, _shape, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def rand_uniform( self, @@ -400,32 +345,42 @@ def rand_uniform( def _arange( self, - *args: int | float, + start: int | float, + stop: int | float, + step: int | float, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, **kwargs: int | float, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(torch.arange)( - *args, dtype=_dtype, device_mesh=device_mesh + default_type = ( + float if any(isinstance(x, float) for x in (start, stop, step)) else int ) + _dtype = self._process_dtype(dtype, default_type) + + array = torch.arange(start, stop, step, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + + return array def linspace( self, start: int | float | bool | torch.Tensor, stop: int | float | bool | torch.Tensor, - steps: int | torch.Tensor, + steps: int, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> torch.Tensor: - _dtype: torch.dtype | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(torch.linspace)( - start, stop, steps, dtype=_dtype, device_mesh=device_mesh - ) + _dtype = self._process_dtype(dtype) + + array = torch.linspace(start, stop, steps, dtype=_dtype, device=self._device) + if self._parallel_manager is not None: + array = self._parallel_manager.parallelize( + array, self.base_device_mesh, device_mesh + ) + return array def flatten( self, input: torch.Tensor, start_dim: int = 0, end_dim: int = -1 @@ -691,3 +646,15 @@ def jacrev(self, fn: Callable[..., dict[str, torch.Tensor]]) -> Callable: def jacfwd(self, fn: Callable[..., dict[str, torch.Tensor]]) -> Callable: return torch_jacfwd(fn) + + def _process_dtype( + self, + dtype: Dtype | None = None, + default_type: type[float] | type[int] | type[bool] = float, + ) -> torch.dtype: + if isinstance(dtype, Dtype): + return utils.dtype_map[dtype.name] + elif dtype is None: + return utils.dtype_map[default_type.__name__ + str(self.precision)] + else: + raise ValueError(f"Invalid dtype {dtype}") diff --git a/mithril/backends/with_autograd/torch_backend/utils.py b/mithril/backends/with_autograd/torch_backend/utils.py index 3b49cabb..85b326ed 100644 --- a/mithril/backends/with_autograd/torch_backend/utils.py +++ b/mithril/backends/with_autograd/torch_backend/utils.py @@ -187,76 +187,6 @@ def get_available_devices() -> list[str]: return devices -def handle_dtype(dtype: core.Dtype | torch.dtype | str) -> Any: - if isinstance(dtype, core.Dtype): - return dtype_map[dtype.name] - elif isinstance(dtype, torch.dtype): - return dtype - elif dtype in dtype_map: - return dtype_map[dtype] - raise TypeError(f"Provided data type '{dtype}' not understood") - - -def creation_fn_wrapper_inner( - *args: Any, - dtype: core.Dtype | 
torch.dtype | str | None = None, - fn: Callable[..., torch.Tensor], - device: str, - precision: int, - device_mesh: tuple[int, ...] | None = None, - **kwargs: Any, -): - _device = get_device(device) - if dtype is not None: - dtype = handle_dtype(dtype) - data = fn(*args, dtype=dtype, device=_device, **kwargs) - else: - data = fn(*args, device=_device, **kwargs) - data = handle_data_precision(data, precision=precision) - - return data - - -def conversion_fn_wrapper_inner( - data: Any, - *args: Any, - dtype: torch.dtype | str | None = None, - fn: Callable[..., torch.Tensor], - device: str, - precision: int, - **kwargs: Any, -) -> torch.Tensor: - _device = get_device(device) - if dtype is not None: - dtype = handle_dtype(dtype) - if isinstance(data, torch.Tensor): - if data.device != _device: - data = data.to(_device) - if dtype is not None: - return data.type(dtype) - return handle_data_precision(data, precision=precision) - elif isinstance(data, np.ndarray): - _data = fn(data, *args, dtype=dtype, device=_device, **kwargs) - if ( - dtype is None and _data.dtype != torch.bool - ): # User did not specify dtype explicitly - return handle_data_precision(_data, precision=precision) - return _data - else: - # To determine subtype we are creating tensor twice in worst case - _data = fn(data, *args, dtype=dtype, device=device, **kwargs) - if ( - dtype is None - and get_precision(_data) != precision - and _data.dtype != torch.bool - ): - subtype = get_subtype(_data) - _dtype = getattr(torch, f"{subtype}{precision}") - _data = fn(data, *args, dtype=_dtype, device=device, **kwargs) - return _data - return _data - - def handle_data_precision(data: torch.Tensor, precision: int) -> torch.Tensor: _dtype = data.dtype dtype: torch.dtype @@ -762,7 +692,7 @@ def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str dtype_name = "".join( char for char in input.dtype.__str__().split(".")[1] if not char.isdigit() ) - elif isinstance(input, np.ndarray) or isinstance(input, np.generic): + elif isinstance(input, (np.ndarray | np.generic)): dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit()) else: dtype_name = find_dominant_type(input).__name__ diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py index bcf95afb..da7c91de 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/backend.py +++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py @@ -13,7 +13,6 @@ # limitations under the License. from collections.abc import Callable -from functools import partial from typing import Any import numpy as np @@ -101,52 +100,6 @@ def set_seed(self, seed: int): self.seed = seed np.random.seed(seed) - def _creation_fn_wrapper( - self, fn: Callable[..., np.ndarray[Any, Any]] - ) -> Callable[..., np.ndarray[Any, Any]]: - """ - Wrapper for NumPy array creation functions. - - Parameters - ---------- - fn: Callable - The original array creation function. - - Returns - ------- - Callable - A wrapped function that creates NumPy arrays with specified dtype. - - Notes - ----- - This wrapper ensures that NumPy arrays are created with the correct dtype. - """ - return partial(utils.creation_fn_wrapper, fn=fn, precision=self.precision) - - def _conversion_fn_wrapper( - self, fn: Callable[..., np.ndarray[Any, Any]] - ) -> Callable[..., np.ndarray[Any, Any]]: - """ - Wrapper for NumPy array conversion functions. - - Parameters - ---------- - fn: Callable - The original array conversion function. 
- - Returns - ------- - Callable - A wrapped function that converts arrays to NumPy arrays with - specified dtype. - - Notes - ----- - This wrapper handles the conversion of arrays to NumPy arrays with - different dtypes. - """ - return partial(utils.conversion_fn_wrapper, fn=fn, precision=self.precision) - def accumulate_grads( self, gradient: np.ndarray[Any, Any], @@ -164,36 +117,28 @@ def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any def zeros( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.zeros)(shape=_shape, dtype=_dtype) + return np.zeros(_shape, dtype=_dtype) def ones( self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.ones)(shape=_shape, dtype=_dtype) + return np.ones(_shape, dtype=_dtype) def ones_like( self, input: np.ndarray[Any, Any], *, dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(np.ones_like)(input, dtype=_dtype) + _dtype = self._process_dtype(dtype) + return np.ones_like(input, dtype=_dtype) def zeros_like( self, input: np.ndarray[Any, Any], *, dtype: Dtype | None = None ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(np.zeros_like)(input, dtype=_dtype) + _dtype = self._process_dtype(dtype) + return np.zeros_like(input, dtype=_dtype) def randn( self, @@ -201,11 +146,9 @@ def randn( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.randn)(*_shape, dtype=_dtype) + return np.random.randn(*_shape).astype(_dtype) def rand( self, @@ -213,11 +156,9 @@ def rand( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.rand)(*_shape, dtype=_dtype) + return np.random.rand(*_shape).astype(_dtype) def randint( self, @@ -227,13 +168,9 @@ def randint( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype, int) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.randint)( - low=low, high=high, size=_shape, dtype=_dtype - ) + return np.random.randint(low, high, size=_shape).astype(_dtype) def rand_uniform( self, @@ -243,34 +180,33 @@ def rand_uniform( dtype: Dtype | None = None, prng_key: Any = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - 
_dtype = utils.dtype_map[dtype.name] + _dtype = self._process_dtype(dtype) _shape = process_shape(shape) - return self._creation_fn_wrapper(np.random.uniform)( - low=low, high=high, size=_shape, dtype=_dtype - ) + return np.random.uniform(low, high, size=_shape).astype(_dtype) - def arange( - self, *args: int | float, dtype: Dtype | None = None + def _arange( + self, + start: int | float, + stop: int | float, + step: int | float, + dtype: Dtype | None = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(np.arange)(*args, dtype=_dtype) + default_type = ( + float if any(isinstance(x, float) for x in (start, stop, step)) else int + ) + _dtype = self._process_dtype(dtype, default_type) + return np.arange(start, stop, step, dtype=_dtype) def linspace( self, start: int | float | bool | np.ndarray[Any, Any], stop: int | float | bool | np.ndarray[Any, Any], - steps: int | np.ndarray[Any, Any], + steps: int, dtype: Dtype | None = None, device_mesh: tuple[int, ...] | None = None, ) -> np.ndarray[Any, Any]: - _dtype: np.dtype[Any] | None = None - if isinstance(dtype, Dtype): - _dtype = utils.dtype_map[dtype.name] - return self._creation_fn_wrapper(np.linspace)(start, stop, steps, dtype=_dtype) + _dtype = self._process_dtype(dtype) + return np.linspace(start, stop, steps, dtype=_dtype) def flatten( self, input: np.ndarray[Any, Any], start_dim: int = 0, end_dim: int = -1 @@ -458,3 +394,15 @@ def multinomial( samples = np.squeeze(samples, axis=0) return samples + + def _process_dtype( + self, + dtype: Dtype | None = None, + default_type: type[float] | type[int] | type[bool] = float, + ) -> np.dtype[Any]: + if isinstance(dtype, Dtype): + return utils.dtype_map[dtype.name] + elif dtype is None: + return utils.dtype_map[default_type.__name__ + str(self.precision)] + else: + raise ValueError(f"Invalid dtype {dtype}") diff --git a/mithril/backends/with_manualgrad/numpy_backend/utils.py b/mithril/backends/with_manualgrad/numpy_backend/utils.py index 528855fe..2ab96a0a 100644 --- a/mithril/backends/with_manualgrad/numpy_backend/utils.py +++ b/mithril/backends/with_manualgrad/numpy_backend/utils.py @@ -308,55 +308,6 @@ def calc_input_slices( return slices -def handle_dtype(dtype: Any) -> Any: - if isinstance(dtype, core.Dtype): - return dtype_map[dtype.name] - elif isinstance(dtype, str) and dtype in dtype_map: - return dtype_map[dtype] - else: - try: - return np.dtype(dtype) - except TypeError as err: - raise TypeError(f"Provided data type '{dtype}' not understood") from err - - -def creation_fn_wrapper( - *args: Any, - fn: Callable[..., np.ndarray[Any, Any]], - precision: int, - dtype: core.Dtype | np.dtype[Any] | None = None, - **kwargs: Any, -): - if dtype is not None: - dtype = handle_dtype(dtype) - data = fn(*args, dtype=dtype, **kwargs) - else: - data = fn(*args, **kwargs) - data = handle_data_precision(data, precision=precision) - return data - - -def conversion_fn_wrapper( - data: Any, - *args: Any, - fn: Callable[..., np.ndarray[Any, Any]], - precision: int, - dtype: np.dtype[Any] | None = None, - **kwargs: Any, -): - if dtype is not None: - dtype = handle_dtype(dtype) - if isinstance(data, ArrayType): - if dtype is not None: - return data.astype(dtype) - return handle_data_precision(data, precision=precision) - else: - _data = fn(data, *args, dtype=dtype, **kwargs) - if dtype is None: - return handle_data_precision(_data, precision=precision) - return _data - - def 
handle_data_precision( data: np.ndarray[Any, Any], precision: int ) -> np.ndarray[Any, Any]: @@ -501,7 +452,7 @@ def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str if isinstance(dtype, core.Dtype): return dtype.name - if isinstance(input, np.ndarray) or isinstance(input, np.generic): + if isinstance(input, (np.ndarray | np.generic)): dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit()) else: dtype_name = find_dominant_type(input).__name__
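Usage sketch (illustrative, not part of the patch): with the creation-fn wrappers removed, the base `Backend.arange` now validates its positional arguments itself and dispatches to each backend's `_arange(start, stop, step, ...)`, while the new `_process_dtype` helpers pick a default dtype from the argument types and the backend precision. A minimal sketch of the intended call semantics, assuming a NumPy backend built with 32-bit precision (the import path and constructor argument below are assumptions, not taken from this patch):

    import mithril as ml

    backend = ml.NumpyBackend(precision=32)  # assumed constructor signature

    # arange(stop): start defaults to 0 and step to 1; all-integer arguments
    # resolve to the backend's default integer dtype (int32 here).
    backend.arange(5)           # -> array([0, 1, 2, 3, 4], dtype=int32)

    # arange(start, stop): an empty array is returned when start >= stop.
    backend.arange(3, 1)        # -> array([], ...)

    # arange(start, stop, step): any float argument switches the default
    # dtype to the backend's float dtype (float32 here).
    backend.arange(0, 1, 0.25)  # -> array([0., 0.25, 0.5, 0.75], dtype=float32)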