chore: Fix type errors backend files (#92)
Co-authored-by: kberat-synnada <97015093+kberat-synnada@users.noreply.github.com>
mehmetozsoy-synnada and kberat-synnada authored Dec 24, 2024
1 parent 46983c3 commit e645f2a
Showing 38 changed files with 1,363 additions and 1,168 deletions.
2 changes: 1 addition & 1 deletion benchmarks/speed_benchmarks/speed_helper.py
@@ -129,7 +129,7 @@ def measure_time_and_grads_mithril(

grads = model.evaluate_gradients(trainable_params, data)

- if model.backend.type == "mlx":
+ if model.backend.backend_type == "mlx":
mx.eval(grads)
trainable_params = {
key: value - lr * grads[key] for key, value in trainable_params.items()
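Usage sketch for the renamed attribute (illustrative, not part of this commit; assumes mithril and the torch backend are installed):

import mithril as ml

# After this change, callers check `backend.backend_type` instead of the
# old `backend.type` attribute.
backend = ml.TorchBackend()
if backend.backend_type == "mlx":
    print("running on MLX; results are evaluated lazily")
else:
    print(f"running on {backend.backend_type}")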
3 changes: 2 additions & 1 deletion examples/gpt/run_sample.py
@@ -16,6 +16,7 @@
import sys
import warnings
from collections.abc import Callable
+ from typing import Any

import tiktoken
from model import create_gpt
@@ -119,7 +120,7 @@ def get_weights(backend: Backend):


def generate(
- model: PhysicalModel,
+ model: PhysicalModel[Any],
block_size: int,
weights: dict[str, ml.DataType],
idx: ml.DataType,
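The new annotation reflects that PhysicalModel is generic over the backend's array type. A simplified, self-contained sketch of the pattern; the class below is a stand-in, not Mithril's actual implementation:

from typing import Any, Generic, TypeVar

DataType = TypeVar("DataType")

class PhysicalModel(Generic[DataType]):  # stand-in for mithril's PhysicalModel
    def __init__(self, weights: dict[str, DataType]) -> None:
        self.weights = weights

# Strict type checkers require an explicit type parameter; `Any` is used when
# the concrete array type (torch.Tensor, jax.Array, ...) is not pinned down.
def generate(model: PhysicalModel[Any], block_size: int) -> None:
    ...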
64 changes: 33 additions & 31 deletions mithril/backends/backend.py
@@ -32,11 +32,11 @@
class Backend(ABC, Generic[DataType]):
"""Base class for backend implementations in the Mithril library."""

- type = ""
+ backend_type = ""
device_type = None
supported_precisions = [16, 32, 64]
is_installed = True
- _device: str
+ _device: Any
_precision: int
primitive_function_dict: dict[str, Callable[..., DataType | Any]]
registered_primitives: dict[str, Callable[..., DataType]]
@@ -65,7 +65,7 @@ def device(self):
return self._device

@property
- def inf(self):
+ def inf(self) -> DataType | float:
raise NotImplementedError("inf is not implemented")

@property
@@ -80,29 +80,34 @@ def e(self):
def is_manualgrad(self) -> bool:
raise NotImplementedError("is_manualgrad is not implemented")

- def get_backend_array_type(self): # noqa: B902
+ def get_backend_array_type(self) -> type[DataType]:
raise NotImplementedError("get_backend_array_type is not implemented")

@staticmethod
- def register_primitive(fn: Callable) -> None:
+ def register_primitive(fn: Callable[..., Any]) -> None:
raise NotImplementedError("register_primitive is not implemented!")

@abstractmethod
- def set_seed(self, seed: int):
+ def set_seed(self, seed: int) -> None:
raise NotImplementedError(
"set_seed function must be overriden for every backend individually!"
)

- def to_device(self, data: DataType, device: str, asynchronous: bool = True):
+ def to_device(
+ self, data: DataType, device: str, asynchronous: bool = True
+ ) -> DataType:
raise RuntimeError("Backend does not support to_device method!")

- def block_until_ready(self, data: DataType):
+ def block_until_ready(self, data: DataType) -> DataType | None:
raise RuntimeError("Backend does not support block_until_ready method!")

def empty_cache(self): # noqa: B027
pass
# print("Warning: empty_cache is not supported!")

+ # TODO: Fix types in cast function when python
+ # adds Higher-Kinded TypeVar support.
+ # https://github.com/python/typing/issues/548#issuecomment-1193345123
def cast(self, value: Any) -> Any:
# Simply casts given value to the backend's precision.
# If type of value is not int or float, returns the
@@ -141,7 +146,7 @@ def arange(
dtype: core.Dtype | None = None,
) -> DataType: ...

- def arange(self, *args: int | float, **kwargs) -> DataType:
+ def arange(self, *args: int | float, **kwargs: Any) -> DataType:
raise NotImplementedError("arange is not implemented!")

def flatten(
@@ -255,7 +260,7 @@ def isnan(self, input: DataType) -> DataType:
"""
raise NotImplementedError("isnan is not implemented!")

- def array(self, data: Any, *, dtype: core.Dtype | None = None) -> DataType:
+ def array(self, input: Any, *, dtype: core.Dtype | None = None) -> DataType:
"""Returns a backend array on speficied device by copying `data`.
Parameters
@@ -316,7 +321,7 @@ def ones(
raise NotImplementedError("ones is not implemented!")

def ones_like(
- self, array: DataType, *, dtype: core.Dtype | None = None
+ self, input: DataType, *, dtype: core.Dtype | None = None
) -> DataType:
"""Returns a new backend array filled with ones, with the same size,
same dtype and same device with the given array.
@@ -337,7 +342,7 @@ def ones_like(
raise NotImplementedError("ones_like is not implemented!")

def zeros_like(
- self, array: DataType, *, dtype: core.Dtype | None = None
+ self, input: DataType, *, dtype: core.Dtype | None = None
) -> DataType:
"""Returns a new backend array filled with zeros, with the same size,
same dtype and same device with the given array.
@@ -588,7 +593,7 @@ def softplus(self, input: DataType) -> DataType:
"""
raise NotImplementedError("softplus is not implemented!")

- def stop_gradient(self, data: DataType) -> DataType:
+ def stop_gradient(self, input: DataType) -> DataType:
"""
Stop the gradient computation for the given data.
@@ -677,7 +682,7 @@ def expand_dims(self, input: DataType, axis: int) -> DataType:
"""
raise NotImplementedError("expand_dims is not implemented!")

- def stack(self, arrays: list[DataType], axis: int = 0) -> DataType:
+ def stack(self, inputs: list[DataType], axis: int = 0) -> DataType:
"""
Stack a sequence of arrays along a new axis.
@@ -693,7 +698,7 @@ def stack(self, arrays: list[DataType], axis: int = 0) -> DataType:
"""
raise NotImplementedError("stack is not implemented!")

- def cat(self, arrays: list[DataType], axis: int = 0) -> DataType:
+ def cat(self, inputs: list[DataType], axis: int = 0) -> DataType:
"""
Concatenate a sequence of arrays along an existing axis.
@@ -814,12 +819,12 @@ def any(self, input: DataType) -> DataType:
raise NotImplementedError("any is not implemented!")

def transpose(
- self, data: DataType, axes: tuple[int, ...] | list[int] | None
+ self, input: DataType, axes: tuple[int, ...] | list[int] | None
) -> DataType:
raise NotImplementedError()

def unique(
- self, input: DataType, **kwargs
+ self, input: DataType, **kwargs: Any
) -> tuple[DataType, DataType | None, DataType | None]:
raise NotImplementedError("unique is not implemented!")

@@ -830,15 +835,11 @@ def where(self, cond: DataType, input1: DataType, input2: DataType) -> DataType:
raise NotImplementedError("where is not implemented!")

def multinomial(
- self,
- probs: DataType,
- num_samples: int,
- replacement: bool = False,
- **kwargs,
+ self, probs: DataType, num_samples: int, replacement: bool = False
) -> DataType:
raise NotImplementedError("multinomial is not implemented!")

- def jit(self, fn: Callable) -> Callable:
+ def jit[T: Any](self, fn: Callable[..., T]) -> Callable[..., T]:
"""
Just-in-time compile the given function.
@@ -982,7 +983,7 @@ def vjp(
"""
raise NotImplementedError("vjp is not implemented!")

- def vmap(self, fn: Callable) -> Callable:
+ def vmap[T: Callable[..., Any]](self, fn: T) -> T:
"""
Vectorize the given function.
@@ -1052,7 +1053,7 @@ def __init__(self, device_mesh: tuple[int, ...] | None) -> None:

self._raw_device_mesh = device_mesh
self.n_devices = math.prod(device_mesh) if device_mesh is not None else 1
- self._parallel_manager: Parallel | None
+ self._parallel_manager: Parallel[DataType] | None

def zeros(
self,
@@ -1370,22 +1371,23 @@ def linspace(

raise NotImplementedError("linspace is not implemented!")

- def _register_callable(
- self, fn: Callable | partial, fn_name: str, jit: bool
+ def _register_callable[T: Any](
+ self, fn: Callable[..., T] | partial[T], fn_name: str, jit: bool
) -> None:
raise NotImplementedError()

- def _run_callable(self, *primals, fn_name: str):
+ def _run_callable(self, *primals: Any, fn_name: str) -> Any:
raise NotImplementedError()

- def _create_parallel(self, device_mesh: tuple[int, ...]) -> Parallel:
+ def _create_parallel(self, device_mesh: tuple[int, ...]) -> None:
raise NotImplementedError(
f"{self.type.capitalize()} backend does not support parallelization!"
f"{self.backend_type.capitalize()} "
+ "backend does not support parallelization!"
)


class UnavailableBackend:
is_installed = False

- def __init__(self, *args, **kwargs) -> None:
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
raise RuntimeError("Backend is unavailable due to missing dependencies.")
12 changes: 7 additions & 5 deletions mithril/backends/parallel.py
@@ -14,15 +14,15 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
- from typing import Generic
+ from typing import Any, Generic

from ..core import DataType


class Parallel(ABC, Generic[DataType]):
- def __init__(self, n_devices) -> None:
+ def __init__(self, n_devices: int) -> None:
self.n_devices = n_devices
- self.callables: dict[str, Callable] = {}
+ self.callables: dict[str, Callable[..., Any]] = {}

if self.n_devices <= 1:
raise ValueError(
@@ -31,11 +31,13 @@ def __init__(self, n_devices) -> None:
)

@abstractmethod
- def run_callable(self, *primals, fn_name: str):
+ def run_callable(self, *primals: Any, fn_name: str) -> dict[str, Any]:
raise NotImplementedError()

@abstractmethod
- def parallelize(self, tensor: DataType, device_mesh: tuple[int, ...] | None = None):
+ def parallelize(
+ self, tensor: DataType, device_mesh: tuple[int, ...] | None = None
+ ) -> dict[str, Any]:
raise NotImplementedError()

def clean_up(self):
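With Parallel now explicitly generic, a concrete parallel manager parameterizes it with its backend's array type. A hedged sketch, assuming mithril is importable; the list-based stand-in type and trivial bodies are illustrative only (the real DataType TypeVar is tied to backend array types):

from typing import Any

from mithril.backends.parallel import Parallel

class ListParallel(Parallel[list[float]]):  # toy DataType for illustration
    def run_callable(self, *primals: Any, fn_name: str) -> dict[str, Any]:
        return {fn_name: primals}

    def parallelize(
        self, tensor: list[float], device_mesh: tuple[int, ...] | None = None
    ) -> dict[str, Any]:
        return {"tensor": tensor, "device_mesh": device_mesh}

manager = ListParallel(n_devices=2)  # __init__ above rejects n_devices <= 1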
44 changes: 29 additions & 15 deletions mithril/backends/with_autograd/common_primitives.py
@@ -139,11 +139,13 @@ def squared_error(input: DataType, target: DataType):
return (input - target) ** 2


- def minus(input: DataType):
+ def minus(input: DataType) -> DataType:
return -input


- def transpose(input: DataType, axes: tuple[int, ...] | list[int] | None = None):
+ def transpose(
+ input: DataType, axes: tuple[int, ...] | list[int] | None = None
+ ) -> DataType:
if not axes:
return input.T
return input.transpose(*axes)
@@ -167,11 +169,11 @@ def buffer(input: DataType):
return input


- def permute_tensor(input: DataType, indices: DataType):
+ def permute_tensor(input: DataType, indices: DataType) -> DataType:
return input[indices] # type: ignore


- def reshape(input: DataType, shape: tuple[int, ...]):
+ def reshape(input: DataType, shape: tuple[int, ...]) -> DataType:
return input.reshape(shape)


@@ -191,7 +193,7 @@ def cartesian_diff(left: DataType, right: DataType):
return left[:, None, :] - right[None, :, :]


- def primitive_embedding(input: DataType, weight: DataType):
+ def primitive_embedding(input: DataType, weight: DataType) -> DataType:
return weight[input] # type: ignore


@@ -226,7 +228,10 @@ def to_list(*args: tuple[int | float | bool, ...]):
return list(args)


- def padding_converter_1d(input, kernel_size):
+ def padding_converter_1d(
+ input: PaddingType | int | Sequence[int], kernel_size: tuple[int, int] | int
+ ) -> tuple[int, int]:
+ output: tuple[int, int]
if isinstance(input, PaddingType):
if input == PaddingType.VALID:
output = (0, 0)
@@ -243,15 +248,18 @@ def padding_converter_1d(input, kernel_size):
elif isinstance(input, int):
output = (input, input)

- elif isinstance(input, Sequence):
+ else:
if isinstance(input[0], Sequence) or isinstance(input[1], Sequence):
raise RuntimeError(f"Given input '{input}' is not valid!")
- output = tuple(input)
+ output = (input[0], input[1])

return output


- def padding_converter_2d(input, kernel_size):
+ def padding_converter_2d(
+ input: PaddingType | int | Sequence[int] | Sequence[Sequence[int]],
+ kernel_size: tuple[int, int] | int,
+ ) -> tuple[int, int] | tuple[tuple[int, int], tuple[int, int]]:
+ output: tuple[int, int] | tuple[tuple[int, int], tuple[int, int]]
if isinstance(input, PaddingType):
if input == PaddingType.VALID:
@@ -262,18 +270,16 @@ def padding_converter_2d(input, kernel_size):
"'same' padding is not supported when the kernel size is even!"
)
output = (kernel_size[0] // 2, kernel_size[1] // 2)
- elif isinstance(kernel_size, int):
+ else:
if kernel_size % 2 == 0:
raise RuntimeError(
"'same' padding is not supported when the kernel size is even!"
)
half = kernel_size // 2
output = ((half, half), (half, half))
- else:
- raise RuntimeError("Kernel size must be 'tuple[int, int]' or 'int'!")
elif isinstance(input, int):
output = (input, input)
- elif isinstance(input, Sequence):
+ else:
if isinstance(input[0], int) and isinstance(input[1], int):
output = (input[0], input[1])
elif isinstance(input[0], Sequence) and isinstance(input[1], Sequence):
@@ -284,14 +290,22 @@ def padding_converter_2d(input, kernel_size):
return output


- def stride_converter(input, kernel_size):
+ def stride_converter(
+ input: int | PaddingType | tuple[int, int] | None,
+ kernel_size: int | tuple[int, int],
+ ):
if input is None:
return kernel_size
else:
return input


- def tuple_converter(input):
+ def tuple_converter(
+ input: int
+ | PaddingType
+ | tuple[int, int]
+ | tuple[tuple[int, int], tuple[int, int]],
+ ):
if isinstance(input, int):
return (input, input)
else:
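A few worked examples for the newly typed converters; the results follow directly from the branches shown in this diff, and the import path assumes mithril is installed as a package laid out like this repository:

from mithril.backends.with_autograd.common_primitives import (
    padding_converter_1d,
    stride_converter,
    tuple_converter,
)

assert padding_converter_1d(2, kernel_size=5) == (2, 2)       # plain int is mirrored
assert padding_converter_1d([1, 3], kernel_size=5) == (1, 3)  # valid sequence passes through
assert stride_converter(None, kernel_size=(2, 2)) == (2, 2)   # None falls back to the kernel size
assert tuple_converter(3) == (3, 3)                           # int widened to a pair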