Commit b80c64b

review updates

aturker-synnada committed Jan 20, 2025
1 parent e53a922 commit b80c64b

Showing 2 changed files with 19 additions and 61 deletions.
10 changes: 5 additions & 5 deletions mithril/framework/codegen/numpy_gen.py
@@ -21,7 +21,6 @@
import numpy as np

from ...backends.with_manualgrad.numpy_backend import NumpyBackend
from ...core import Dtype
from ...framework.physical.model import PhysicalModel
from ...framework.utils import find_intersection_type
from ...utils.func_utils import is_make_array_required, prepare_function_args
@@ -175,11 +174,12 @@ def evaluate_gradients_wrapper_manualgrad(
out_data = params[_key]
else:
out_data = _key_cache["output"]
# dtype = getattr(self.backend, f"float{self.backend.precision}")

assert isinstance(out_data, np.ndarray)
# dtype = getattr(Dtype, f"float{self.backend.precision}")
dtype = Dtype[f"float{self.backend.precision}"]
gradients[key] = self.backend.zeros_like(out_data, dtype=dtype)

gradients[key] = self.backend.zeros_like(
out_data, dtype=self.backend._dtype
)

if output_gradients is None:
if FinalCost not in self.pm._output_keys:
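For context on the hunk above: gradient buffers are now zero-initialized with the dtype the backend was constructed with (self.backend._dtype) rather than a Dtype member rebuilt from the precision string, which also lets the Dtype import be dropped. A minimal sketch of that behaviour, using a hypothetical ToyBackend stand-in rather than the real NumpyBackend:

import numpy as np

class ToyBackend:
    # Hypothetical stand-in: stores the configured dtype on `_dtype`,
    # mirroring how the generated gradient code reads it from the backend.
    def __init__(self, dtype=np.float32):
        self._dtype = dtype

    def zeros_like(self, arr, dtype=None):
        # Gradient buffers are zero-initialized in the backend's own dtype.
        return np.zeros_like(arr, dtype=dtype)

backend = ToyBackend(dtype=np.float32)
out_data = np.ones(3, dtype=np.float64)
grad = backend.zeros_like(out_data, dtype=backend._dtype)
assert grad.dtype == np.float32  # follows the backend's dtype, not the output's float64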
70 changes: 14 additions & 56 deletions tests/scripts/test_all_models.py
@@ -234,18 +234,6 @@ def compile_and_compare(
# Primitive Model Tests


def test_jax():
arr = [1.0, 2.0, 3.0]
backends = [
JaxBackend(dtype=mithril.float16),
JaxBackend(dtype=mithril.float32),
JaxBackend(dtype=mithril.float64),
JaxBackend(dtype=mithril.bfloat16),
]
for backend in backends:
backend.array(arr)


def test_buffer_1():
model = Buffer()
model.set_types(input=Tensor)
@@ -2477,12 +2465,14 @@ def test_cast_int16():
inp_float = np.array([1, -2, 3], dtype=np.float32)
backends: list[TorchBackend | JaxBackend | NumpyBackend | MlxBackend] = [
TorchBackend(dtype=mithril.float16),
TorchBackend(dtype=mithril.bfloat16),
TorchBackend(dtype=mithril.float32),
TorchBackend(dtype=mithril.float64),
NumpyBackend(dtype=mithril.float16),
NumpyBackend(dtype=mithril.float32),
NumpyBackend(dtype=mithril.float64),
JaxBackend(dtype=mithril.float16),
JaxBackend(dtype=mithril.bfloat16),
JaxBackend(dtype=mithril.float32),
JaxBackend(dtype=mithril.float64),
]
@@ -2524,12 +2514,14 @@ def test_cast_int32():
inp_float = np.array([1, -2, 3], dtype=np.float32)
backends: list[Backend] = [
TorchBackend(dtype=mithril.float16),
TorchBackend(dtype=mithril.bfloat16),
TorchBackend(dtype=mithril.float32),
TorchBackend(dtype=mithril.float64),
NumpyBackend(dtype=mithril.float16),
NumpyBackend(dtype=mithril.float32),
NumpyBackend(dtype=mithril.float64),
JaxBackend(dtype=mithril.float16),
JaxBackend(dtype=mithril.bfloat16),
JaxBackend(dtype=mithril.float32),
JaxBackend(dtype=mithril.float64),
]
@@ -2570,12 +2562,14 @@ def test_cast_int64():
inp_float = np.array([1, -2, 3], dtype=np.float32)
backends: list[Backend] = [
TorchBackend(dtype=mithril.float16),
TorchBackend(dtype=mithril.bfloat16),
TorchBackend(dtype=mithril.float32),
TorchBackend(dtype=mithril.float64),
NumpyBackend(dtype=mithril.float16),
NumpyBackend(dtype=mithril.float32),
NumpyBackend(dtype=mithril.float64),
JaxBackend(dtype=mithril.float16),
JaxBackend(dtype=mithril.bfloat16),
JaxBackend(dtype=mithril.float32),
JaxBackend(dtype=mithril.float64),
]
@@ -2614,12 +2608,14 @@ def test_cast_float16():
inp_float = np.array([1, -2, 3], dtype=np.float32)
backends: list[TorchBackend | JaxBackend | NumpyBackend | MlxBackend] = [
TorchBackend(dtype=mithril.float16),
TorchBackend(dtype=mithril.bfloat16),
TorchBackend(dtype=mithril.float32),
TorchBackend(dtype=mithril.float64),
NumpyBackend(dtype=mithril.float16),
NumpyBackend(dtype=mithril.float32),
NumpyBackend(dtype=mithril.float64),
JaxBackend(dtype=mithril.float16),
JaxBackend(dtype=mithril.bfloat16),
JaxBackend(dtype=mithril.float32),
JaxBackend(dtype=mithril.float64),
]
@@ -2653,62 +2649,20 @@ def test_cast_float16():
np.testing.assert_allclose(res, reference_outputs["output"]) # type: ignore


# def test_cast_bfloat16():
# model = Cast(dtype=mithril.bfloat16)
# inp_int = np.array([1, -2, 3], dtype=np.int32)
# inp_float = np.array([1, -2, 3], dtype=np.float32)
# backends: list[TorchBackend | JaxBackend | NumpyBackend | MlxBackend] = [
# TorchBackend(dtype=mithril.float16),
# TorchBackend(dtype=mithril.bfloat16),
# TorchBackend(dtype=mithril.float32),
# TorchBackend(dtype=mithril.float64),
# JaxBackend(dtype=mithril.float16),
# JaxBackend(dtype=mithril.bfloat16),
# JaxBackend(dtype=mithril.float32),
# JaxBackend(dtype=mithril.float64),
# ]

# if platform.system() == "Darwin":
# backends += [
# MlxBackend(dtype=mithril.float16),
# MlxBackend(dtype=mithril.bfloat16),
# MlxBackend(),
# ]

# expected_dtypes = {
# "torch": torch.bfloat16,
# "jax": jax.numpy.bfloat16,
# "mlx": mx.bfloat16,
# }

# statics = {"inp_int": inp_int, "inp_float": inp_float}

# for backend in backends:
# for static in statics.values():
# _static = backend.array(static)
# pm = mithril.compile(
# model,
# backend, # type: ignore
# constant_keys={"input": _static},
# inference=True,
# )
# res = pm.evaluate()["output"]
# assert isinstance(res, backend.DataType)
# assert res.dtype == expected_dtypes[backend.backend_type]


def test_cast_float32():
model = Cast(dtype=mithril.float32)
inp_int = np.array([1, -2, 3], dtype=np.int32)
inp_float = np.array([1, -2, 3], dtype=np.float32)
backends: list[Backend] = [
TorchBackend(dtype=mithril.float16),
TorchBackend(dtype=mithril.bfloat16),
TorchBackend(dtype=mithril.float32),
TorchBackend(dtype=mithril.float64),
NumpyBackend(dtype=mithril.float16),
NumpyBackend(dtype=mithril.float32),
NumpyBackend(dtype=mithril.float64),
JaxBackend(dtype=mithril.float16),
JaxBackend(dtype=mithril.bfloat16),
JaxBackend(dtype=mithril.float32),
JaxBackend(dtype=mithril.float64),
]
@@ -2749,12 +2703,14 @@ def test_cast_float64():
inp_float = np.array([1, -2, 3], dtype=np.float32)
backends: list[Backend] = [
TorchBackend(dtype=mithril.float16),
TorchBackend(dtype=mithril.bfloat16),
TorchBackend(dtype=mithril.float32),
TorchBackend(dtype=mithril.float64),
NumpyBackend(dtype=mithril.float16),
NumpyBackend(dtype=mithril.float32),
NumpyBackend(dtype=mithril.float64),
JaxBackend(dtype=mithril.float16),
JaxBackend(dtype=mithril.bfloat16),
JaxBackend(dtype=mithril.float32),
JaxBackend(dtype=mithril.float64),
]
@@ -2791,12 +2747,14 @@ def test_cast_bool():
inp_float = np.array([1, -2, 3], dtype=np.float32)
backends: list[Backend] = [
TorchBackend(dtype=mithril.float16),
TorchBackend(dtype=mithril.bfloat16),
TorchBackend(dtype=mithril.float32),
TorchBackend(dtype=mithril.float64),
NumpyBackend(dtype=mithril.float16),
NumpyBackend(dtype=mithril.float32),
NumpyBackend(dtype=mithril.float64),
JaxBackend(dtype=mithril.float16),
JaxBackend(dtype=mithril.bfloat16),
JaxBackend(dtype=mithril.float32),
JaxBackend(dtype=mithril.float64),
]
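The cast tests above now also instantiate bfloat16 TorchBackend and JaxBackend variants. A condensed sketch of the pattern these tests share, modeled on the commented-out test_cast_bfloat16 that was removed; the Cast import path and the expected-dtype table are illustrative assumptions, not the file's exact parametrisation:

import jax
import numpy as np
import torch

import mithril
from mithril import JaxBackend, TorchBackend
from mithril.models import Cast  # import path assumed for illustration

model = Cast(dtype=mithril.float32)
statics = {
    "inp_int": np.array([1, -2, 3], dtype=np.int32),
    "inp_float": np.array([1, -2, 3], dtype=np.float32),
}
# bfloat16 source backends are now part of the sweep.
backends = [TorchBackend(dtype=mithril.bfloat16), JaxBackend(dtype=mithril.bfloat16)]
expected_dtypes = {"torch": torch.float32, "jax": jax.numpy.float32}

for backend in backends:
    for static in statics.values():
        pm = mithril.compile(
            model,
            backend,
            constant_keys={"input": backend.array(static)},
            inference=True,
        )
        res = pm.evaluate()["output"]
        # Regardless of the backend's own dtype, Cast(float32) should determine the output dtype.
        assert res.dtype == expected_dtypes[backend.backend_type]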
