diff --git a/mithril/backends/backend.py b/mithril/backends/backend.py
index 18fb312e..2fe89aeb 100644
--- a/mithril/backends/backend.py
+++ b/mithril/backends/backend.py
@@ -36,7 +36,6 @@ class Backend(ABC, Generic[DataType]):
     device_type = None
     is_installed = True
     _device: Any
-    _precision: int
     supported_dtypes = [
         core.Dtype.float16,
         core.Dtype.bfloat16,
@@ -58,7 +57,7 @@ def __init__(self, dtype: core.Dtype = core.float32, device: str = "cpu") -> Non
 
     @property
     def precision(self) -> int:
-        return self._precision
+        raise NotImplementedError("Precision is not implemented!")  #!!
 
     @property
diff --git a/mithril/backends/with_autograd/jax_backend/backend.py b/mithril/backends/with_autograd/jax_backend/backend.py
index 840f16e1..55bcdb55 100644
--- a/mithril/backends/with_autograd/jax_backend/backend.py
+++ b/mithril/backends/with_autograd/jax_backend/backend.py
@@ -57,7 +57,6 @@ def __init__(
         self._device = device
         utils.get_device(device)  # Check device is available
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._parallel_manager: JaxParallel | None = None
 
         super().__init__(dtype=dtype, device_mesh=device_mesh)
@@ -83,6 +82,10 @@ def inf(self) -> float:
     def nan(self) -> float:
         return jax.numpy.nan
 
+    @property
+    def precision(self) -> int:
+        return DtypeBits[self._dtype.name].value
+
     def get_backend_array_type(self) -> type[jax.Array]:
         return jax.Array
@@ -172,7 +175,7 @@ def array(
         dtype: Dtype | None = None,
         device_mesh: tuple[int, ...] | None = None,
     ) -> jax.Array:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)
         with jax.default_device(self.device):
             array = jax.numpy.array(input, dtype=utils.dtype_map[_dtype])
diff --git a/mithril/backends/with_autograd/mlx_backend/backend.py b/mithril/backends/with_autograd/mlx_backend/backend.py
index 7f447af5..01c8792a 100644
--- a/mithril/backends/with_autograd/mlx_backend/backend.py
+++ b/mithril/backends/with_autograd/mlx_backend/backend.py
@@ -44,7 +44,6 @@ def __init__(
         os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
 
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._device = device
 
         super().__init__(dtype=dtype)
@@ -68,6 +67,10 @@ def nan(self) -> float:
     def device(self) -> Any:
         utils.get_device(self._device)
 
+    @property
+    def precision(self) -> int:
+        return DtypeBits[self._dtype.name].value
+
     def get_device(self) -> Any:
         return self._device
@@ -177,7 +180,7 @@ def _handle_sequence_type_fun(
         return [output]
 
     def array(self, input: Any, *, dtype: Dtype | None = None) -> mx.array:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)
         return mx.array(input, dtype=utils.dtype_map[_dtype])
 
     def zeros(
diff --git a/mithril/backends/with_autograd/torch_backend/backend.py b/mithril/backends/with_autograd/torch_backend/backend.py
index 2492ee04..851006f0 100644
--- a/mithril/backends/with_autograd/torch_backend/backend.py
+++ b/mithril/backends/with_autograd/torch_backend/backend.py
@@ -58,7 +58,6 @@ def __init__(
     ) -> None:
         self._device = device
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._parallel_manager: TorchParallel | None = None
 
         utils.get_device(device)  # Check if device is valid
@@ -92,6 +91,10 @@ def DataType(self) -> type[torch.Tensor]:  # noqa: N802
     def device(self) -> torch.device:
         return utils.get_device(self._device)
 
+    @property
+    def precision(self) -> int:
+        return DtypeBits[self._dtype.name].value
+
     def get_backend_array_type(self) -> type[torch.Tensor]:
         return torch.Tensor
@@ -207,7 +210,7 @@ def array(
         dtype: Dtype | None = None,
         device_mesh: tuple[int, ...] | None = None,
     ) -> torch.Tensor:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)
         array = torch.tensor(input, dtype=utils.dtype_map[_dtype], device=self._device)
 
         if self._parallel_manager is not None:
diff --git a/mithril/backends/with_manualgrad/c_backend/backend.py b/mithril/backends/with_manualgrad/c_backend/backend.py
index bc215780..ea1b1a74 100644
--- a/mithril/backends/with_manualgrad/c_backend/backend.py
+++ b/mithril/backends/with_manualgrad/c_backend/backend.py
@@ -30,7 +30,6 @@ class CBackend(Backend[PyArray]):
     SRC_PATH = "mithril/backends/with_manualgrad/c_backend/src"
 
     def __init__(self) -> None:
-        self._precision = 32
         self._device = "cpu"
         self.primitive_function_dict = {}
@@ -38,6 +37,10 @@ def __init__(self) -> None:
     def is_manualgrad(self) -> bool:
         return True
 
+    @property
+    def precision(self) -> int:
+        return 32
+
     def set_seed(self, seed: int) -> None:
         pass
diff --git a/mithril/backends/with_manualgrad/numpy_backend/backend.py b/mithril/backends/with_manualgrad/numpy_backend/backend.py
index b0207883..ad88eb6d 100644
--- a/mithril/backends/with_manualgrad/numpy_backend/backend.py
+++ b/mithril/backends/with_manualgrad/numpy_backend/backend.py
@@ -46,7 +46,6 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]):
 
     def __init__(self, device: str = "cpu", dtype: Dtype = Dtype.float32) -> None:
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
 
         if device != "cpu":
             raise RuntimeError(
@@ -78,6 +77,10 @@ def nan(self) -> float:
     def DataType(self) -> type[np.ndarray[Any, Any]]:  # noqa: N802
         return utils.ArrayType
 
+    @property
+    def precision(self) -> int:
+        return DtypeBits[self._dtype.name].value
+
     def get_backend_array_type(self) -> type[np.ndarray[Any, Any]]:
         return np.ndarray
@@ -118,7 +121,7 @@ def accumulate_grads(
         return utils.accumulate_grads(gradient, input, cache, idx)
 
     def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]:
-        _dtype = utils.determine_dtype(data, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(data, dtype, self._dtype, self.precision)
        return np.array(data, dtype=utils.dtype_map[_dtype])
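
Note on the pattern applied above: every backend drops the cached _precision attribute set in __init__ and instead exposes precision as a property derived from its configured dtype (the C backend hard-codes 32, and the abstract base class now raises NotImplementedError). The following is a minimal, self-contained sketch of that idea; ToyBackend and the simplified Dtype/DtypeBits enums are illustrative stand-ins, not mithril's actual classes.

    # Hypothetical sketch: precision derived from the dtype on access
    # rather than cached at construction time (names are illustrative only).
    from enum import Enum, IntEnum

    class Dtype(Enum):
        float16 = "float16"
        bfloat16 = "bfloat16"
        float32 = "float32"
        float64 = "float64"

    class DtypeBits(IntEnum):
        float16 = 16
        bfloat16 = 16
        float32 = 32
        float64 = 64

    class ToyBackend:
        def __init__(self, dtype: Dtype = Dtype.float32) -> None:
            self._dtype = dtype  # only the dtype is stored; no _precision field

        @property
        def precision(self) -> int:
            # Computed on every access, so it always agrees with self._dtype.
            return DtypeBits[self._dtype.name].value

    backend = ToyBackend(Dtype.bfloat16)
    print(backend.precision)  # -> 16

Deriving the value on access removes a redundant piece of per-backend state and keeps precision consistent with the dtype even if the dtype were changed after construction.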