
Commit

precision is now property
aturker-synnada committed Jan 20, 2025
1 parent b80c64b commit 19ae15a
Showing 6 changed files with 25 additions and 11 deletions.
3 changes: 1 addition & 2 deletions mithril/backends/backend.py
@@ -36,7 +36,6 @@ class Backend(ABC, Generic[DataType]):
     device_type = None
     is_installed = True
     _device: Any
-    _precision: int
     supported_dtypes = [
         core.Dtype.float16,
         core.Dtype.bfloat16,
@@ -58,7 +57,7 @@ def __init__(self, dtype: core.Dtype = core.float32, device: str = "cpu") -> None:

     @property
     def precision(self) -> int:
-        return self._precision
+        raise NotImplementedError("Precision is not implemented!")

     #!!
     @property
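For readers skimming the diff, the pattern the commit applies is small enough to sketch in one self-contained snippet. This is an illustrative stand-in, not mithril's actual code: the Backend/NumpyLikeBackend classes and the DtypeBits bit widths below are simplified assumptions. The base class stops caching _precision and instead exposes precision as a property that raises until a concrete backend overrides it to derive the bit width from its dtype.

from enum import IntEnum


class DtypeBits(IntEnum):
    # Assumed bit widths; stands in for the DtypeBits enum referenced in the diff.
    float16 = 16
    bfloat16 = 16
    float32 = 32
    float64 = 64


class Backend:
    def __init__(self, dtype: str = "float32") -> None:
        self._dtype = dtype

    @property
    def precision(self) -> int:
        # Base-class behavior after this commit: no cached _precision,
        # just a property that concrete backends must override.
        raise NotImplementedError("Precision is not implemented!")


class NumpyLikeBackend(Backend):
    @property
    def precision(self) -> int:
        # Concrete backends compute the bit width from the current dtype on access.
        return DtypeBits[self._dtype].value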
7 changes: 5 additions & 2 deletions mithril/backends/with_autograd/jax_backend/backend.py
@@ -57,7 +57,6 @@ def __init__(
         self._device = device
         utils.get_device(device)  # Check device is available
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._parallel_manager: JaxParallel | None = None

         super().__init__(dtype=dtype, device_mesh=device_mesh)
@@ -83,6 +82,10 @@ def inf(self) -> float:
     def nan(self) -> float:
         return jax.numpy.nan

+    @property
+    def precision(self) -> int:
+        return DtypeBits[self._dtype.name].value
+
     def get_backend_array_type(self) -> type[jax.Array]:
         return jax.Array

@@ -172,7 +175,7 @@ def array(
         dtype: Dtype | None = None,
         device_mesh: tuple[int, ...] | None = None,
     ) -> jax.Array:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)

         with jax.default_device(self.device):
             array = jax.numpy.array(input, dtype=utils.dtype_map[_dtype])
7 changes: 5 additions & 2 deletions mithril/backends/with_autograd/mlx_backend/backend.py
@@ -44,7 +44,6 @@ def __init__(
         os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._device = device
         super().__init__(dtype=dtype)

@@ -68,6 +67,10 @@ def nan(self) -> float:
     def device(self) -> Any:
         utils.get_device(self._device)

+    @property
+    def precision(self) -> int:
+        return DtypeBits[self._dtype.name].value
+
     def get_device(self) -> Any:
         return self._device

@@ -177,7 +180,7 @@ def _handle_sequence_type_fun(
         return [output]

     def array(self, input: Any, *, dtype: Dtype | None = None) -> mx.array:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)
         return mx.array(input, dtype=utils.dtype_map[_dtype])

     def zeros(
7 changes: 5 additions & 2 deletions mithril/backends/with_autograd/torch_backend/backend.py
@@ -58,7 +58,6 @@ def __init__(
     ) -> None:
         self._device = device
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value
         self._parallel_manager: TorchParallel | None = None

         utils.get_device(device)  # Check if device is valid
@@ -92,6 +91,10 @@ def DataType(self) -> type[torch.Tensor]:  # noqa: N802
     def device(self) -> torch.device:
         return utils.get_device(self._device)

+    @property
+    def precision(self) -> int:
+        return DtypeBits[self._dtype.name].value
+
     def get_backend_array_type(self) -> type[torch.Tensor]:
         return torch.Tensor

@@ -207,7 +210,7 @@ def array(
         dtype: Dtype | None = None,
         device_mesh: tuple[int, ...] | None = None,
     ) -> torch.Tensor:
-        _dtype = utils.determine_dtype(input, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(input, dtype, self._dtype, self.precision)

         array = torch.tensor(input, dtype=utils.dtype_map[_dtype], device=self._device)
         if self._parallel_manager is not None:
5 changes: 4 additions & 1 deletion mithril/backends/with_manualgrad/c_backend/backend.py
@@ -30,14 +30,17 @@ class CBackend(Backend[PyArray]):
     SRC_PATH = "mithril/backends/with_manualgrad/c_backend/src"

     def __init__(self) -> None:
-        self._precision = 32
         self._device = "cpu"
         self.primitive_function_dict = {}

     @property
     def is_manualgrad(self) -> bool:
         return True

+    @property
+    def precision(self) -> int:
+        return 32
+
     def set_seed(self, seed: int) -> None:
         pass

7 changes: 5 additions & 2 deletions mithril/backends/with_manualgrad/numpy_backend/backend.py
@@ -46,7 +46,6 @@ class NumpyBackend(Backend[np.ndarray[Any, Any]]):

     def __init__(self, device: str = "cpu", dtype: Dtype = Dtype.float32) -> None:
         self._dtype = dtype
-        self._precision = DtypeBits[dtype.name].value

         if device != "cpu":
             raise RuntimeError(
@@ -78,6 +77,10 @@ def nan(self) -> float:
     def DataType(self) -> type[np.ndarray[Any, Any]]:  # noqa: N802
         return utils.ArrayType

+    @property
+    def precision(self) -> int:
+        return DtypeBits[self._dtype.name].value
+
     def get_backend_array_type(self) -> type[np.ndarray[Any, Any]]:
         return np.ndarray

@@ -118,7 +121,7 @@ def accumulate_grads(
         return utils.accumulate_grads(gradient, input, cache, idx)

     def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]:
-        _dtype = utils.determine_dtype(data, dtype, self._dtype, self._precision)
+        _dtype = utils.determine_dtype(data, dtype, self._dtype, self.precision)

         return np.array(data, dtype=utils.dtype_map[_dtype])

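As a quick usage sketch built on the stand-in classes above (again assumed names, not mithril's public API), the net effect of the commit is that the reported precision always tracks the backend's current dtype instead of a value cached once in __init__:

backend = NumpyLikeBackend(dtype="float16")
assert backend.precision == 16  # bit width derived from the dtype at access time

backend._dtype = "float64"      # if the dtype ever changes...
assert backend.precision == 64  # ...precision follows it automatically

The C backend, which has no configurable dtype, simply pins the property to 32.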
