From b80c64b74fe4e69285c3a7c2b11f964857fdab08 Mon Sep 17 00:00:00 2001
From: aturker-synnada
Date: Mon, 20 Jan 2025 13:47:16 +0300
Subject: [PATCH] review updates

---
 mithril/framework/codegen/numpy_gen.py | 10 ++--
 tests/scripts/test_all_models.py       | 70 ++++++--------------------
 2 files changed, 19 insertions(+), 61 deletions(-)

diff --git a/mithril/framework/codegen/numpy_gen.py b/mithril/framework/codegen/numpy_gen.py
index fbad975f..d2925d96 100644
--- a/mithril/framework/codegen/numpy_gen.py
+++ b/mithril/framework/codegen/numpy_gen.py
@@ -21,7 +21,6 @@
 import numpy as np
 
 from ...backends.with_manualgrad.numpy_backend import NumpyBackend
-from ...core import Dtype
 from ...framework.physical.model import PhysicalModel
 from ...framework.utils import find_intersection_type
 from ...utils.func_utils import is_make_array_required, prepare_function_args
@@ -175,11 +174,12 @@ def evaluate_gradients_wrapper_manualgrad(
                 out_data = params[_key]
             else:
                 out_data = _key_cache["output"]
-            # dtype = getattr(self.backend, f"float{self.backend.precision}")
+            assert isinstance(out_data, np.ndarray)
-            # dtype = getattr(Dtype, f"float{self.backend.precision}")
-            dtype = Dtype[f"float{self.backend.precision}"]
-            gradients[key] = self.backend.zeros_like(out_data, dtype=dtype)
+
+            gradients[key] = self.backend.zeros_like(
+                out_data, dtype=self.backend._dtype
+            )
 
         if output_gradients is None:
             if FinalCost not in self.pm._output_keys:
diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py
index 446a6e5b..3dc90151 100644
--- a/tests/scripts/test_all_models.py
+++ b/tests/scripts/test_all_models.py
@@ -234,18 +234,6 @@ def compile_and_compare(
 # Primitive Model Tests
 
 
-def test_jax():
-    arr = [1.0, 2.0, 3.0]
-    backends = [
-        JaxBackend(dtype=mithril.float16),
-        JaxBackend(dtype=mithril.float32),
-        JaxBackend(dtype=mithril.float64),
-        JaxBackend(dtype=mithril.bfloat16),
-    ]
-    for backend in backends:
-        backend.array(arr)
-
-
 def test_buffer_1():
     model = Buffer()
     model.set_types(input=Tensor)
@@ -2477,12 +2465,14 @@ def test_cast_int16():
     inp_float = np.array([1, -2, 3], dtype=np.float32)
     backends: list[TorchBackend | JaxBackend | NumpyBackend | MlxBackend] = [
         TorchBackend(dtype=mithril.float16),
+        TorchBackend(dtype=mithril.bfloat16),
         TorchBackend(dtype=mithril.float32),
         TorchBackend(dtype=mithril.float64),
         NumpyBackend(dtype=mithril.float16),
         NumpyBackend(dtype=mithril.float32),
         NumpyBackend(dtype=mithril.float64),
         JaxBackend(dtype=mithril.float16),
+        JaxBackend(dtype=mithril.bfloat16),
         JaxBackend(dtype=mithril.float32),
         JaxBackend(dtype=mithril.float64),
     ]
@@ -2524,12 +2514,14 @@ def test_cast_int32():
     inp_float = np.array([1, -2, 3], dtype=np.float32)
     backends: list[Backend] = [
         TorchBackend(dtype=mithril.float16),
+        TorchBackend(dtype=mithril.bfloat16),
         TorchBackend(dtype=mithril.float32),
         TorchBackend(dtype=mithril.float64),
         NumpyBackend(dtype=mithril.float16),
         NumpyBackend(dtype=mithril.float32),
         NumpyBackend(dtype=mithril.float64),
         JaxBackend(dtype=mithril.float16),
+        JaxBackend(dtype=mithril.bfloat16),
         JaxBackend(dtype=mithril.float32),
         JaxBackend(dtype=mithril.float64),
     ]
@@ -2570,12 +2562,14 @@ def test_cast_int64():
     inp_float = np.array([1, -2, 3], dtype=np.float32)
     backends: list[Backend] = [
         TorchBackend(dtype=mithril.float16),
+        TorchBackend(dtype=mithril.bfloat16),
         TorchBackend(dtype=mithril.float32),
         TorchBackend(dtype=mithril.float64),
         NumpyBackend(dtype=mithril.float16),
         NumpyBackend(dtype=mithril.float32),
         NumpyBackend(dtype=mithril.float64),
         JaxBackend(dtype=mithril.float16),
+        JaxBackend(dtype=mithril.bfloat16),
         JaxBackend(dtype=mithril.float32),
         JaxBackend(dtype=mithril.float64),
     ]
@@ -2614,12 +2608,14 @@ def test_cast_float16():
     inp_float = np.array([1, -2, 3], dtype=np.float32)
     backends: list[TorchBackend | JaxBackend | NumpyBackend | MlxBackend] = [
         TorchBackend(dtype=mithril.float16),
+        TorchBackend(dtype=mithril.bfloat16),
         TorchBackend(dtype=mithril.float32),
         TorchBackend(dtype=mithril.float64),
         NumpyBackend(dtype=mithril.float16),
         NumpyBackend(dtype=mithril.float32),
         NumpyBackend(dtype=mithril.float64),
         JaxBackend(dtype=mithril.float16),
+        JaxBackend(dtype=mithril.bfloat16),
         JaxBackend(dtype=mithril.float32),
         JaxBackend(dtype=mithril.float64),
     ]
@@ -2653,62 +2649,20 @@ def test_cast_float16():
     np.testing.assert_allclose(res, reference_outputs["output"])  # type: ignore
 
 
-# def test_cast_bfloat16():
-#     model = Cast(dtype=mithril.bfloat16)
-#     inp_int = np.array([1, -2, 3], dtype=np.int32)
-#     inp_float = np.array([1, -2, 3], dtype=np.float32)
-#     backends: list[TorchBackend | JaxBackend | NumpyBackend | MlxBackend] = [
-#         TorchBackend(dtype=mithril.float16),
-#         TorchBackend(dtype=mithril.bfloat16),
-#         TorchBackend(dtype=mithril.float32),
-#         TorchBackend(dtype=mithril.float64),
-#         JaxBackend(dtype=mithril.float16),
-#         JaxBackend(dtype=mithril.bfloat16),
-#         JaxBackend(dtype=mithril.float32),
-#         JaxBackend(dtype=mithril.float64),
-#     ]
-
-#     if platform.system() == "Darwin":
-#         backends += [
-#             MlxBackend(dtype=mithril.float16),
-#             MlxBackend(dtype=mithril.bfloat16),
-#             MlxBackend(),
-#         ]
-
-#     expected_dtypes = {
-#         "torch": torch.bfloat16,
-#         "jax": jax.numpy.bfloat16,
-#         "mlx": mx.bfloat16,
-#     }
-
-#     statics = {"inp_int": inp_int, "inp_float": inp_float}
-
-#     for backend in backends:
-#         for static in statics.values():
-#             _static = backend.array(static)
-#             pm = mithril.compile(
-#                 model,
-#                 backend,  # type: ignore
-#                 constant_keys={"input": _static},
-#                 inference=True,
-#             )
-#             res = pm.evaluate()["output"]
-#             assert isinstance(res, backend.DataType)
-#             assert res.dtype == expected_dtypes[backend.backend_type]
-
-
 def test_cast_float32():
     model = Cast(dtype=mithril.float32)
     inp_int = np.array([1, -2, 3], dtype=np.int32)
     inp_float = np.array([1, -2, 3], dtype=np.float32)
     backends: list[Backend] = [
         TorchBackend(dtype=mithril.float16),
+        TorchBackend(dtype=mithril.bfloat16),
         TorchBackend(dtype=mithril.float32),
         TorchBackend(dtype=mithril.float64),
         NumpyBackend(dtype=mithril.float16),
         NumpyBackend(dtype=mithril.float32),
         NumpyBackend(dtype=mithril.float64),
         JaxBackend(dtype=mithril.float16),
+        JaxBackend(dtype=mithril.bfloat16),
         JaxBackend(dtype=mithril.float32),
         JaxBackend(dtype=mithril.float64),
     ]
@@ -2749,12 +2703,14 @@ def test_cast_float64():
     inp_float = np.array([1, -2, 3], dtype=np.float32)
     backends: list[Backend] = [
         TorchBackend(dtype=mithril.float16),
+        TorchBackend(dtype=mithril.bfloat16),
         TorchBackend(dtype=mithril.float32),
         TorchBackend(dtype=mithril.float64),
         NumpyBackend(dtype=mithril.float16),
         NumpyBackend(dtype=mithril.float32),
         NumpyBackend(dtype=mithril.float64),
         JaxBackend(dtype=mithril.float16),
+        JaxBackend(dtype=mithril.bfloat16),
         JaxBackend(dtype=mithril.float32),
         JaxBackend(dtype=mithril.float64),
     ]
@@ -2791,12 +2747,14 @@ def test_cast_bool():
     inp_float = np.array([1, -2, 3], dtype=np.float32)
     backends: list[Backend] = [
         TorchBackend(dtype=mithril.float16),
+        TorchBackend(dtype=mithril.bfloat16),
         TorchBackend(dtype=mithril.float32),
         TorchBackend(dtype=mithril.float64),
         NumpyBackend(dtype=mithril.float16),
         NumpyBackend(dtype=mithril.float32),
         NumpyBackend(dtype=mithril.float64),
         JaxBackend(dtype=mithril.float16),
+        JaxBackend(dtype=mithril.bfloat16),
         JaxBackend(dtype=mithril.float32),
         JaxBackend(dtype=mithril.float64),
     ]