
Commit

Remove conversion fn wrapper
aturker-synnada committed Dec 25, 2024
1 parent a513e10 commit 0b89d65
Showing 9 changed files with 104 additions and 92 deletions.
13 changes: 13 additions & 0 deletions mithril/backends/utils.py
@@ -13,6 +13,7 @@
# limitations under the License.

from collections.abc import Sequence
from typing import Any

from ..utils.type_utils import is_tuple_int

@@ -36,3 +37,15 @@ def process_shape(
)

return _shape


def determine_shape(lst: Any) -> tuple[int, ...]:
if isinstance(lst, list):
if not lst:
return (0,)
first_shape = determine_shape(lst[0])
for item in lst:
if determine_shape(item) != first_shape:
raise ValueError("Inhomogeneous shapes detected in the list.")
return (len(lst),) + first_shape
return ()
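
Editor's note: a quick sketch of how the new helper behaves on a few inputs (the example values below are mine, not part of the commit):

# Hypothetical inputs; determine_shape walks nested lists recursively.
determine_shape([[1, 2], [3, 4], [5, 6]])  # (3, 2)
determine_shape([])                        # (0,): an empty list yields a zero-length axis
determine_shape(7)                         # (): non-list values are treated as scalars
determine_shape([[1, 2], [3]])             # raises ValueError (ragged sublists)
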
48 changes: 8 additions & 40 deletions mithril/backends/with_autograd/jax_backend/backend.py
@@ -172,40 +172,6 @@ def _creation_fn_wrapper(

return array_conversion_fn

def _conversion_fn_wrapper(
self, fn: Callable[..., jax.Array]
) -> Callable[..., jax.Array]:
"""
Wrapper for array conversion functions.
Parameters
----------
fn: Callable
The original array conversion function.
Returns
-------
Callable
A wrapped function that converts arrays with specified dtype and device.
Notes
-----
Handles the conversion of arrays between different dtypes and devices.
If dtype is provided, it uses `utils._handle_dtype` to ensure a valid dtype.
If the input data is a JAX Array, it ensures it's on the specified device.
If dtype is not provided, uses the default device and handles data precision.
"""
array_conversion_fn = partial(
utils.conversion_fn_wrapper,
fn=fn,
device=self._device,
precision=self.precision,
)
array_conversion_fn = partial(self._parallelize, fn=array_conversion_fn)

return array_conversion_fn

def _parallelize(
self,
*args: Any,
@@ -264,13 +230,15 @@ def array(
dtype: Dtype | None = None,
device_mesh: tuple[int, ...] | None = None,
) -> jax.Array:
-        _dtype: jax.numpy.dtype[Any] | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
-        result = self._conversion_fn_wrapper(jax.numpy.array)(
-            input, dtype=_dtype, device_mesh=device_mesh
-        )
-        return result
+        _dtype = utils.determine_dtype(input, dtype, self.precision)
+
+        array = jax.numpy.array(
+            input, dtype=utils.dtype_map[_dtype], device=self.device
+        )
+        if self._parallel_manager is not None:
+            array = self._parallel_manager.parallelize(array, device_mesh)
+
+        return array

def zeros(
self,
17 changes: 17 additions & 0 deletions mithril/backends/with_autograd/jax_backend/utils.py
@@ -17,6 +17,7 @@

import jax
import jax.numpy as jnp
import numpy as np
from jax import vmap

from .... import core
@@ -504,3 +505,19 @@ def calculate_cross_entropy_class_weights(
shape[1] = input.shape[1]
_weights = _weights.reshape(shape)
return _weights


def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str:
if isinstance(dtype, core.Dtype):
return dtype.name

if isinstance(input, jax.Array):
dtype_name = "".join(
char for char in input.dtype.__str__() if not char.isdigit()
)
elif isinstance(input, np.ndarray) or isinstance(input, np.generic):
dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit())
else:
dtype_name = find_dominant_type(input).__name__

return dtype_name + str(precision) if dtype_name != "bool" else "bool"
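
Editor's note: a short usage sketch of the new helper. The calls and expected strings below are my own, and they assume find_dominant_type maps plain Python scalars to int/float/bool, that core.Dtype exposes the usual members, and that these import paths resolve as written:

import jax.numpy as jnp
from mithril import core

# An explicit Dtype wins outright.
determine_dtype([1.5, 2.5], core.Dtype.float64, 32)  # "float64"
# Otherwise the name is inferred from the input and suffixed with the backend precision.
determine_dtype(jnp.ones(2), None, 16)               # float32 array -> "float16"
determine_dtype([1, 2, 3], None, 64)                 # Python ints   -> "int64"
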
6 changes: 2 additions & 4 deletions mithril/backends/with_autograd/mlx_backend/backend.py
@@ -188,10 +188,8 @@ def _handle_sequence_type_fun(
return [output]

def array(self, input: Any, *, dtype: Dtype | None = None) -> mx.array:
-        _dtype: mx.Dtype | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
-        return self._conversion_fn_wrapper(mx.array)(input, dtype=_dtype)
+        _dtype = utils.determine_dtype(input, dtype, self.precision)
+        return mx.array(input, dtype=utils.dtype_map[_dtype])

def zeros(
self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None
16 changes: 16 additions & 0 deletions mithril/backends/with_autograd/mlx_backend/utils.py
@@ -420,6 +420,22 @@ def get_submatrices2d(
)


def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str:
if isinstance(dtype, core.Dtype):
return dtype.name

if isinstance(input, mx.array):
dtype_name = "".join(
char for char in input.dtype.__str__().split(".")[-1] if not char.isdigit()
)
elif isinstance(input, np.ndarray) or isinstance(input, np.generic):
dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit())
else:
dtype_name = find_dominant_type(input).__name__

return dtype_name + str(precision) if dtype_name != "bool" else "bool"


def get_type(input: int | float | bool | Sequence[Any], precision: int) -> mx.Dtype:
type = find_dominant_type(input).__name__
if type == "bool":
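
Editor's note on the MLX variant (a sketch; the exact dtype repr is my assumption): MLX dtypes stringify with their module path, something like "mlx.core.float32", which is why this version keeps only the last dotted component before stripping the digits.

import mlx.core as mx

x = mx.array([1.0, 2.0])
str(x.dtype)                  # e.g. "mlx.core.float32"
determine_dtype(x, None, 16)  # "float" + "16" -> "float16"
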
47 changes: 10 additions & 37 deletions mithril/backends/with_autograd/torch_backend/backend.py
@@ -171,37 +171,6 @@ def _creation_fn_wrapper(

return array_creation_fn

def _conversion_fn_wrapper(
self, fn: Callable[..., torch.Tensor]
) -> Callable[..., torch.Tensor]:
"""
Wrapper for PyTorch tensor conversion functions.
Parameters
----------
fn: Callable
The original tensor conversion function.
Returns
-------
Callable
A wrapped function that converts tensors with specified dtype and device.
Notes
-----
Wrapper handles the conversion of tensors between different dtypes and devices.
"""

array_conversion_fn = partial(
utils.conversion_fn_wrapper_inner,
fn=fn,
device=self._device,
precision=self.precision,
)
array_conversion_fn = partial(self._parallelize, fn=array_conversion_fn)

return array_conversion_fn

def _parallelize(
self,
*args: Any,
@@ -298,12 +267,16 @@ def array(
dtype: Dtype | None = None,
device_mesh: tuple[int, ...] | None = None,
) -> torch.Tensor:
-        _dtype: torch.dtype | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
-        return self._conversion_fn_wrapper(torch.tensor)(
-            input, dtype=_dtype, device_mesh=device_mesh
-        )
+        _dtype = utils.determine_dtype(input, dtype, self.precision)
+
+        tensor = torch.tensor(input, dtype=utils.dtype_map[_dtype], device=self._device)
+
+        if self._parallel_manager is not None:
+            tensor = self._parallel_manager.parallelize(
+                tensor, self.base_device_mesh, device_mesh
+            )
+
+        return tensor

def zeros(
self,
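
Editor's note: to make the new control flow concrete, a minimal sketch of what array() now does on a single-device backend. The import path, example input, and 32-bit precision are my assumptions, not part of the diff:

import numpy as np
import torch
from mithril.backends.with_autograd.torch_backend import utils

precision = 32

# Infer the dtype name from the input, then coerce it to the backend precision.
name = utils.determine_dtype(np.ones(3, dtype=np.float64), None, precision)  # "float32"
tensor = torch.tensor(np.ones(3, dtype=np.float64), dtype=utils.dtype_map[name])
# With no parallel manager configured, the tensor is returned as-is.
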
26 changes: 20 additions & 6 deletions mithril/backends/with_autograd/torch_backend/utils.py
@@ -36,6 +36,7 @@
AVAILABLE_BACKEND_TYPES = ["cpu", "cuda"]

ArrayType = torch.Tensor
+NestedTensorType = int | float | bool | Sequence["NestedTensorType"]
dtype_map: dict[str, torch.dtype] = {
"int16": torch.int16,
"int32": torch.int32,
@@ -228,7 +229,7 @@ def conversion_fn_wrapper_inner(
_device = get_device(device)
if dtype is not None:
dtype = handle_dtype(dtype)
-    if isinstance(data, ArrayType):
+    if isinstance(data, torch.Tensor):
if data.device != _device:
data = data.to(_device)
if dtype is not None:
@@ -256,7 +257,7 @@ def conversion_fn_wrapper_inner(
return _data


-def handle_data_precision(data: ArrayType, precision: int) -> ArrayType:
+def handle_data_precision(data: torch.Tensor, precision: int) -> torch.Tensor:
_dtype = data.dtype
dtype: torch.dtype
# Do not make any changes to boolean types.
@@ -276,7 +277,7 @@ def handle_data_precision(data: ArrayType, precision: int) -> ArrayType:
return data


-def handle_data_dtype(data: ArrayType, dtype: core.Dtype | int) -> ArrayType:
+def handle_data_dtype(data: torch.Tensor, dtype: core.Dtype | int) -> torch.Tensor:
dtype = core.Dtype(dtype)

if data.dtype != dtype_map[dtype.name]:
@@ -286,11 +287,11 @@ def handle_data_dtype(data: ArrayType, dtype: core.Dtype | int) -> ArrayType:
return data


-def get_precision(data: ArrayType) -> int:
+def get_precision(data: torch.Tensor) -> int:
return data.dtype.itemsize * 8


-def get_subtype(data: ArrayType) -> str:
+def get_subtype(data: torch.Tensor) -> str:
# TODO: cover uint dtypes
if not torch.is_floating_point(data) and not torch.is_complex(data):
return "int"
@@ -753,7 +754,20 @@ def check_device_mesh(base_mesh: DeviceMesh, device_mesh: tuple[int, ...]):
)


-NestedTensorType = int | float | bool | Sequence["NestedTensorType"]
def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str:
if isinstance(dtype, core.Dtype):
return dtype.name

if isinstance(input, torch.Tensor):
dtype_name = "".join(
char for char in input.dtype.__str__().split(".")[1] if not char.isdigit()
)
elif isinstance(input, np.ndarray) or isinstance(input, np.generic):
dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit())
else:
dtype_name = find_dominant_type(input).__name__

return dtype_name + str(precision) if dtype_name != "bool" else "bool"


def get_type(input: NestedTensorType, precision: int):
9 changes: 4 additions & 5 deletions mithril/backends/with_manualgrad/numpy_backend/backend.py
@@ -156,11 +156,10 @@ def accumulate_grads(
) -> np.ndarray[Any, Any]:
return utils.accumulate_grads(gradient, input, cache, idx)

-    def array(self, input: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]:
-        _dtype: np.dtype[Any] | None = None
-        if isinstance(dtype, Dtype):
-            _dtype = utils.dtype_map[dtype.name]
-        return self._conversion_fn_wrapper(np.array)(input, dtype=_dtype)
+    def array(self, data: Any, *, dtype: Dtype | None = None) -> np.ndarray[Any, Any]:
+        _dtype = utils.determine_dtype(data, dtype, self.precision)
+
+        return np.array(data, dtype=utils.dtype_map[_dtype])

def zeros(
self, *shape: int | tuple[int, ...] | list[int], dtype: Dtype | None = None
14 changes: 14 additions & 0 deletions mithril/backends/with_manualgrad/numpy_backend/utils.py
@@ -497,6 +497,20 @@ def calculate_cross_entropy_class_weights(
return _weights


def determine_dtype(input: Any, dtype: core.Dtype | None, precision: int) -> str:
if isinstance(dtype, core.Dtype):
return dtype.name

if isinstance(input, np.ndarray) or isinstance(input, np.generic):
dtype_name = "".join(char for char in str(input.dtype) if not char.isdigit())
else:
dtype_name = find_dominant_type(input).__name__

print(type(input), dtype, dtype_name)

return dtype_name + str(precision) if dtype_name != "bool" else "bool"


def get_type(
input: int | float | bool | Sequence[int | float | bool | Sequence[Any]],
precision: int,
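
Editor's note: a last sketch covering the bool and numpy-scalar branches of the numpy helper (example values are mine): booleans never receive a precision suffix, while numpy scalars and arrays are renamed to the backend precision, and the inference for plain Python lists is assumed to go through find_dominant_type.

import numpy as np

determine_dtype(np.bool_(True), None, 32)   # "bool" (no precision suffix)
determine_dtype(np.float64(1.5), None, 32)  # "float64" becomes "float32"
determine_dtype([1, 2, 3], None, 64)        # dominant Python type int -> "int64"
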
