feat: Add dtype Parameter to Array Creation Primitives (#165)
aturker-synnada authored Jan 23, 2025
1 parent 06f9445 commit 0726ad4
Showing 29 changed files with 832 additions and 251 deletions.
8 changes: 8 additions & 0 deletions mithril/backends/backend.py
@@ -61,6 +61,10 @@ def __init__(self, dtype: core.Dtype = core.float32, device: str = "cpu") -> None:
    def precision(self) -> int:
        return DtypeBits[self._dtype.name].value

    @property
    def default_dtype(self) -> core.Dtype:
        return self._dtype

    #!!
    @property
    def device(self) -> Any:
@@ -85,6 +89,10 @@ def e(self) -> float:
    def is_manualgrad(self) -> bool:
        raise NotImplementedError("is_manualgrad is not implemented")

    @property
    def codegen_config(self) -> dict[str, bool]:
        raise NotImplementedError("codegen_config is not implemented")

    def get_backend_array_type(self) -> type[DataType]:
        raise NotImplementedError("get_backend_array_type is not implemented")

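For orientation, default_dtype and codegen_config are the two hooks this commit adds to the abstract Backend; each concrete backend below supplies real values. A minimal sketch of the contract (the ToyBackend name and its flag value are invented for illustration, not part of this commit):

# Hypothetical sketch of the contract added above; ToyBackend and its
# flag value are invented for illustration and are not in this commit.
class ToyBackend:
    def __init__(self, dtype: str = "float32") -> None:
        self._dtype = dtype  # stands in for core.Dtype in this sketch

    @property
    def default_dtype(self) -> str:
        # Concrete backends surface the dtype chosen at construction time.
        return self._dtype

    @property
    def codegen_config(self) -> dict[str, bool]:
        # Flags consumed by code generation, e.g. whether to emit device args.
        return {"specify_device": True}
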
8 changes: 8 additions & 0 deletions mithril/backends/with_autograd/jax_backend/backend.py
@@ -24,6 +24,7 @@
from ...utils import DtypeSubTypes, StaticScalar, process_shape
from . import ops, utils
from .parallel import JaxParallel
from .utils import CODEGEN_CONFIG

__all__ = ["JaxBackend"]

@@ -70,6 +71,9 @@ def __init__(
        self.primitive_function_dict = ops.primitive_func_dict
        self.prng_key = jax.random.PRNGKey(self.seed)

        for key, value in utils.dtype_map.items():
            setattr(self, key, value)

    @property
    def is_manualgrad(self) -> bool:
        return False
Expand Down Expand Up @@ -97,6 +101,10 @@ def get_device(self) -> Any:
    def DataType(self) -> type[jax.Array]:  # noqa: N802
        return utils.ArrayType

    @property
    def codegen_config(self) -> dict[str, bool]:
        return CODEGEN_CONFIG

    @staticmethod
    def get_available_devices() -> list[str]:
        """Static method to get a list of available devices.
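The setattr loop in __init__ mirrors utils.dtype_map onto the backend instance, so every dtype name becomes an attribute. A usage sketch, assuming mithril is installed and exports JaxBackend at the top level (import path assumed, not shown in this diff):

# Usage sketch; the import path is assumed, not shown in this diff.
from mithril import JaxBackend

backend = JaxBackend()
print(backend.default_dtype)   # backend-wide fallback dtype (core.Dtype)
print(backend.float32)         # jnp.float32, set by the setattr loop above
print(backend.codegen_config)  # {"specify_device": True}
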
110 changes: 83 additions & 27 deletions mithril/backends/with_autograd/jax_backend/ops.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from collections.abc import Callable, Iterator, Sequence
from functools import partial
from itertools import combinations_with_replacement
@@ -25,6 +26,7 @@

from .... import core
from ....utils.type_utils import is_tuple_int
from ....utils.utils import find_dominant_type
from ...utils import NestedFloatOrIntOrBoolList
from ..common_primitives import (
    add,
@@ -75,11 +77,9 @@
    calculate_binary_class_weight,
    calculate_cross_entropy_class_weights,
    calculate_tpr_fpr,
    dtype_map,
    find_optimal_sigmas,
    get_device,
    get_type,
    handle_data_dtype,
    handle_data_precision,
    log_sigmoid,
    log_softmax,
    many_to_one_inference_helper,
@@ -108,6 +108,7 @@
"cos",
"tanh",
"relu",
"cast",
"leaky_relu",
"sigmoid",
"softplus",
@@ -760,20 +761,6 @@ def kl_divergence(input: jax.Array, target: jax.Array, cutoff: jax.Array) -> jax.Array:
    return target * (robust_log(target, cutoff) - robust_log(input, cutoff))


def eye(N: int, M: int | None, *, device: str, precision: int) -> jax.Array:
    with jax.default_device(get_device(device)):
        return handle_data_precision(jnp.eye(N, M), precision)


def ones_with_zero_diag(
    N: int, M: int | None, device: str, precision: int
) -> jax.Array:
    output = jnp.ones(N) - jnp.eye(N) if M is None else jnp.ones((N, M)) - jnp.eye(N, M)

    with jax.default_device(get_device(device)):
        return handle_data_precision(output, precision)


def transposed_diag(input: jax.Array) -> jax.Array:
    return jnp.diag(input)[:, None]

@@ -853,20 +840,79 @@ def matrix_concat(input1: jax.Array, input2: jax.Array) -> jax.Array:
    return jnp.concatenate((input1, input2), axis=input1.ndim - 1)


### Array creation ops ###


def to_tensor(
    input: NestedFloatOrIntOrBoolList, device: str, precision: int
    input: NestedFloatOrIntOrBoolList,
    *,
    dtype: jnp.dtype[Any] | None = None,
    device: str,
    default_dtype: str,
) -> jax.Array:
    dtype_str = default_dtype if dtype is None else dtype_map.inverse[dtype]

    dominant_type = find_dominant_type(input)
    _dtype = dominant_type.__name__

    if _dtype != "bool":
        _dtype += str(re.findall(r"\d+", dtype_str)[-1])

    with jax.default_device(get_device(device)):
        return jnp.array(input, dtype=get_type(input, precision))
        return jnp.array(input, dtype=dtype_map[_dtype])
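
The regex above keys the result on two inputs: the dominant Python type of the nested list and the bit width of the requested (or default) dtype string. A quick sketch of the rule, assuming the module path from this diff and that find_dominant_type ranks float above int above bool:

# Sketch of to_tensor's dtype inference; import path taken from this diff.
import jax.numpy as jnp
from mithril.backends.with_autograd.jax_backend.ops import to_tensor

x = to_tensor([[1, 2], [3, 4]], dtype=None, device="cpu", default_dtype="float32")
assert x.dtype == jnp.int32    # ints dominate; width "32" comes from the default
y = to_tensor([[1.0, 2], [3, 4]], dtype=None, device="cpu", default_dtype="float32")
assert y.dtype == jnp.float32  # one float promotes the whole nested list
b = to_tensor([True, False], dtype=None, device="cpu", default_dtype="float32")
assert b.dtype == jnp.bool_    # bool carries no width suffix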


def tensor_to_list(input: jax.Array) -> NestedFloatOrIntOrBoolList:
    return input.tolist()
def eye(
    N: int,
    M: int | None,
    *,
    dtype: jnp.dtype[Any] | None = None,
    device: str,
    default_dtype: str,
) -> jax.Array:
    dtype = dtype_map[default_dtype] if dtype is None else dtype
    with jax.default_device(get_device(device)):
        return jnp.eye(N, M, dtype=dtype)


def arange(*args: Any, device: str, precision: int) -> jax.Array:
def ones_with_zero_diag(
    N: int,
    M: int | None,
    *,
    dtype: jnp.dtype[Any] | None = None,
    device: str,
    default_dtype: str,
) -> jax.Array:
    dtype = dtype_map[default_dtype] if dtype is None else dtype

    with jax.default_device(get_device(device)):
        return (
            jnp.ones(N, dtype=dtype) - jnp.eye(N, dtype=dtype)
            if M is None
            else jnp.ones((N, M), dtype=dtype) - jnp.eye(N, M, dtype=dtype)
        )


def arange(
    start: int | float,
    stop: int | float,
    step: int | float,
    *,
    dtype: jnp.dtype[Any] | None = None,
    device: str,
    default_dtype: str,
) -> jax.Array:
    _dtype = default_dtype if dtype is None else dtype_map.inverse[dtype]

    if len([item for item in [start, stop, step] if isinstance(item, float)]) == 0:
        _dtype = _dtype.replace("bfloat", "int").replace("float", "int")  # strip "bfloat" first so "bfloat16" maps to "int16", not "bint16"

    with jax.default_device(get_device(device)):
        return handle_data_precision(jnp.arange(*args), precision)
        return jnp.arange(start, stop, step, dtype=dtype_map[_dtype])
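
Note the integer branch: when start, stop, and step are all ints, the floating default is rewritten to the integer dtype of the same width. A quick check of both paths (same assumed import path as above):

# Sketch of arange's dtype inference; import path taken from this diff.
from mithril.backends.with_autograd.jax_backend.ops import arange

r1 = arange(0, 10, 2, dtype=None, device="cpu", default_dtype="float32")
assert str(r1.dtype) == "int32"    # all-int arguments: "float32" -> "int32"
r2 = arange(0.0, 1.0, 0.25, dtype=None, device="cpu", default_dtype="float32")
assert str(r2.dtype) == "float32"  # any float argument keeps the default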


def tensor_to_list(input: jax.Array) -> NestedFloatOrIntOrBoolList:
    return input.tolist()


def minimum(left: jax.Array, right: jax.Array) -> jax.Array:
@@ -997,8 +1043,8 @@ def nan_to_num(
    return jnp.nan_to_num(input, nan=nan, posinf=posinf, neginf=neginf)  # type: ignore


def astype(input: jax.Array, dtype: core.Dtype | int) -> jax.Array:
    return handle_data_dtype(input, dtype)
def cast(input: jax.Array, dtype: jnp.dtype[Any]) -> jax.Array:
    return input.astype(dtype)


def dtype(input: jax.Array) -> core.Dtype:
@@ -1017,10 +1063,20 @@ def pad(input: jax.Array, pad_width: tuple[tuple[int, int], ...]) -> jax.Array:
    return jax.numpy.pad(input, pad_width)


def randn(shape: tuple[int, ...], key: int, device: str, precision: int) -> jax.Array:
def randn(
    shape: tuple[int, ...],
    key: int,
    *,
    dtype: str | None = None,
    device: str,
    default_dtype: str,
) -> jax.Array:
    _key = jax.random.PRNGKey(key)
    if dtype is None:
        dtype = default_dtype

    with jax.default_device(get_device(device)):
        return handle_data_precision(jax.random.normal(_key, shape), precision)
        return jax.random.normal(_key, shape, dtype=dtype_map[dtype])
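
Unlike the other creation ops, randn threads dtype through as a string and resolves it against dtype_map only at the call site. A short usage sketch (same assumed import path):

# Sketch of randn; key seeds jax.random.PRNGKey, dtype resolves via dtype_map.
from mithril.backends.with_autograd.jax_backend.ops import randn

z = randn((2, 3), 42, dtype=None, device="cpu", default_dtype="float32")
assert z.shape == (2, 3) and str(z.dtype) == "float32"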


def zeros_like(input: jax.Array) -> jax.Array:
34 changes: 17 additions & 17 deletions mithril/backends/with_autograd/jax_backend/utils.py
@@ -21,27 +21,27 @@
from jax import vmap

from .... import core
from ....utils.utils import binary_search, find_dominant_type
from ....utils.utils import BiMap, binary_search, find_dominant_type
from ...utils import DtypeSubTypes

CODEGEN_CONFIG: dict[str, bool] = {
"specify_device": True,
}

ArrayType = jax.Array

dtype_map: dict[str, jnp.dtype[Any]] = {
    "uint8": jnp.uint8,
    "int8": jnp.int8,
    "int16": jnp.int16,
    "int32": jnp.int32,
    "int": jnp.int32,
    "int64": jnp.int64,
    "long": jnp.int64,
    "float16": jnp.float16,
    "bfloat16": jnp.bfloat16,
    "float32": jnp.float32,
    "float": jnp.float32,
    "float64": jnp.float64,
    "double": jnp.float64,
    "bool": jnp.bool_,
}
dtype_map: BiMap[str, jnp.dtype[Any]] = BiMap(
    {
        "int16": jnp.int16,
        "int32": jnp.int32,
        "int64": jnp.int64,
        "float16": jnp.float16,
        "bfloat16": jnp.bfloat16,
        "float32": jnp.float32,
        "float64": jnp.float64,
        "bool": jnp.bool_,
    }
)
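
dtype_map becomes a BiMap rather than a plain dict because the ops now look it up in both directions: dtype_map[name] in eye and arange, dtype_map.inverse[dtype] in to_tensor and arange. The BiMap class lives in mithril.utils.utils and is not shown in this diff; a minimal sketch of the interface the ops rely on, under that assumption:

# Minimal sketch of the BiMap interface assumed by the ops above; the real
# implementation in mithril.utils.utils may differ in detail.
class BiMap(dict):
    def __init__(self, data: dict) -> None:
        super().__init__(data)
        # inverse maps values back to keys, e.g. jnp.float32 -> "float32"
        self.inverse = {value: key for key, value in data.items()}

    def __setitem__(self, key, value) -> None:
        super().__setitem__(key, value)
        self.inverse[value] = key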


def broadcast_to_highest(
8 changes: 8 additions & 0 deletions mithril/backends/with_autograd/mlx_backend/backend.py
@@ -24,6 +24,7 @@
from ...backend import Backend, PadWidthType
from ...utils import DtypeSubTypes, StaticScalar, process_shape
from . import ops, utils
from .utils import CODEGEN_CONFIG

__all__ = ["MlxBackend"]

@@ -51,6 +52,9 @@ def __init__(
        self.primitive_function_dict = ops.primitive_func_dict
        self.prng_key = mx.random.key(self.seed)

        for key, value in utils.dtype_map.items():
            setattr(self, key, value)

    @property
    def is_manualgrad(self) -> bool:
        return False
@@ -67,6 +71,10 @@ def nan(self) -> float:
    def device(self) -> Any:
        utils.get_device(self._device)

    @property
    def codegen_config(self) -> dict[str, bool]:
        return CODEGEN_CONFIG

    def get_device(self) -> Any:
        return self._device

(Diffs for the remaining 24 changed files are not shown.)