diff --git a/mithril/cores/python/jax/ops.py b/mithril/cores/python/jax/ops.py index 12498925..42cd4a26 100644 --- a/mithril/cores/python/jax/ops.py +++ b/mithril/cores/python/jax/ops.py @@ -221,6 +221,7 @@ "dtype", "zeros_like", "avg_pool2d", + "ones", ] @@ -1137,11 +1138,30 @@ def zeros_like(input: jax.Array) -> jax.Array: return jnp.zeros_like(input) +def ones( + shape: tuple[int, ...], + *, + dtype: jnp.dtype[Any] | None = None, + device: str, + default_dtype: str, +) -> jax.Array: + dtype = dtype_map[default_dtype] if dtype is None else dtype + with jax.default_device(get_device(device)): + return jnp.ones(shape, dtype=dtype) + + def atleast_1d(input: jax.Array) -> jax.Array: return jnp.atleast_1d(input) -array_creation_funcs = ["arange", "randn", "to_tensor", "eye", "ones_with_zero_diag"] +array_creation_funcs = [ + "arange", + "randn", + "ones", + "to_tensor", + "eye", + "ones_with_zero_diag", +] primitive_func_dict = common_primitive_func_dict = { key: fn for key, fn in globals().items() if callable(fn) } | common_primitive_func_dict diff --git a/mithril/cores/python/mlx/ops.py b/mithril/cores/python/mlx/ops.py index b6b4bf64..3193c4db 100644 --- a/mithril/cores/python/mlx/ops.py +++ b/mithril/cores/python/mlx/ops.py @@ -71,9 +71,9 @@ tuple_converter, union, ) +from . import utils # from ...backends.with_autograd.mlx_backend import utils -from . import utils AxisType = None | int | Sequence[int] @@ -198,6 +198,7 @@ "dtype", "zeros_like", "avg_pool2d", + "ones", ] @@ -1018,11 +1019,29 @@ def zeros_like(input: mx.array) -> mx.array: return mx.zeros_like(input) +def ones( + shape: tuple[int, ...], + *, + dtype: mx.Dtype | None = None, + device: str, + default_dtype: str, +) -> mx.array: + dtype = utils.dtype_map[default_dtype] if dtype is None else dtype + return mx.ones(shape, dtype=dtype) + + def atleast_1d(input: mx.array) -> mx.array: return mx.atleast_1d(input) -array_creation_funcs = ["arange", "randn", "to_tensor", "eye", "ones_with_zero_diag"] +array_creation_funcs = [ + "arange", + "randn", + "to_tensor", + "eye", + "ones_with_zero_diag", + "ones", +] primitive_func_dict = common_primitive_func_dict = { key: fn for key, fn in globals().items() if callable(fn) } | common_primitive_func_dict diff --git a/mithril/cores/python/numpy/ops.py b/mithril/cores/python/numpy/ops.py index 9ab75ea4..c4ec26c3 100644 --- a/mithril/cores/python/numpy/ops.py +++ b/mithril/cores/python/numpy/ops.py @@ -169,6 +169,7 @@ "dtype", "zeros_like", "avg_pool2d", + "ones", ] @@ -1319,6 +1320,17 @@ def zeros_like( return np.zeros_like(input) +def ones( + shape: tuple[int, ...], + *, + dtype: np.dtype[Any] | None = None, + default_dtype: str, + cache: CacheType | None = None, +) -> np.ndarray[Any, Any]: + dtype = dtype_map[default_dtype] if dtype is None else dtype + return np.ones(shape, dtype=dtype) + + def atleast_1d( input: np.ndarray[Any, Any], cache: CacheType | None = None ) -> np.ndarray[Any, Any]: @@ -1698,6 +1710,7 @@ def cartesian_diff( array_creation_funcs = [ "arange", "randn", + "ones", "to_tensor", "make_array", "eye", diff --git a/mithril/cores/python/torch/ops.py b/mithril/cores/python/torch/ops.py index b97f9896..c49845ea 100644 --- a/mithril/cores/python/torch/ops.py +++ b/mithril/cores/python/torch/ops.py @@ -209,6 +209,7 @@ "dtype", "zeros_like", "avg_pool2d", + "ones", ] @@ -1221,6 +1222,17 @@ def zeros_like(input: torch.Tensor) -> torch.Tensor: return torch.zeros_like(input) +def ones( + shape: tuple[int, ...], + *, + dtype: torch.dtype | None = None, + device: str, + default_dtype: str, +) -> torch.Tensor: + dtype = dtype_map[default_dtype] if dtype is None else dtype + return torch.ones(shape, device=device, dtype=dtype) + + def atleast_1d(input: torch.Tensor) -> torch.Tensor: return torch.atleast_1d(input) # type: ignore @@ -1229,7 +1241,14 @@ def primitive_embedding(input: torch.Tensor, weight: torch.Tensor) -> torch.Tens return weight[input.long()] -array_creation_funcs = ["arange", "randn", "to_tensor", "eye", "ones_with_zero_diag"] +array_creation_funcs = [ + "arange", + "randn", + "ones", + "to_tensor", + "eye", + "ones_with_zero_diag", +] primitive_func_dict = common_primitive_func_dict | { key: fn for key, fn in globals().items() if callable(fn) } diff --git a/mithril/models/primitives.py b/mithril/models/primitives.py index 10eb2eb7..2615a60d 100644 --- a/mithril/models/primitives.py +++ b/mithril/models/primitives.py @@ -177,6 +177,7 @@ "Trapezoid", "Pad", "Randn", + "Ones", "PrimitiveModel", "Buffer", "ToTuple", @@ -2675,6 +2676,35 @@ def __call__( # type: ignore[override] return super().__call__(input=input, output=output) +class Ones(PrimitiveModel): + shape: Connection + dtype: Connection + output: Connection + + def __init__( + self, + shape: tuple[int, ...] | list[int] | ToBeDetermined = TBD, + dtype: types.Dtype | None = None, + *, + name: str | None = None, + ) -> None: + super().__init__( + formula_key="ones", + name=name, + output=BaseKey(shape=[("Var", ...)], type=Tensor), + shape=BaseKey(type=tuple[int, ...] | list[int], value=shape), + dtype=BaseKey(type=types.Dtype | None, value=dtype), + ) + + def __call__( # type: ignore[override] + self, + shape: ConnectionType = NOT_GIVEN, + dtype: ConnectionType = NOT_GIVEN, + output: ConnectionType = NOT_GIVEN, + ) -> ExtendInfo: + return super().__call__(shape=shape, dtype=dtype, output=output) + + class Buffer(OperatorModel): input: Connection output: Connection diff --git a/tests/scripts/test_all_models.py b/tests/scripts/test_all_models.py index 065af6f0..799eceff 100644 --- a/tests/scripts/test_all_models.py +++ b/tests/scripts/test_all_models.py @@ -62,6 +62,7 @@ NanToNum, NormModifier, NotEqual, + Ones, PrimitiveUnion, Prod, Randn, @@ -1532,6 +1533,98 @@ def test_zeros_like(): ) +def test_ones_static_shape(): + model = Ones(shape=(2, 3, 4)) + + reference_outputs = {"output": list_full(1.0, 2, 3, 4)} + compile_and_compare( + model=model, + compile_kwargs={"inference": True}, + data={}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + tolerances=1e-6, + assert_shapes=False, + ) + + +def test_ones_dynamic_shape(): + model = Ones(shape=TBD) + + reference_outputs = {"output": list_full(1.0, 3, 5, 2)} + compile_and_compare( + model=model, + compile_kwargs={"jit": False, "inference": True}, + data={"shape": (3, 5, 2)}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + tolerances=1e-6, + assert_shapes=False, + ignore_transform={"shape"}, + ) + + +def test_ones_static_with_dtype(): + dtypes = [mithril.float16, mithril.float32] + for dtype in dtypes: + backends: list[Backend[Any]] = [ + TorchBackend(dtype=dtype), + NumpyBackend(dtype=dtype), + JaxBackend(dtype=dtype), + ] + if platform.system() == "Darwin": + backends.append(MlxBackend(dtype=dtype)) + + model = Ones(shape=(2, 4), dtype=dtype) + + reference_outputs = {"output": list_full(1.0, 2, 4)} + compile_and_compare( + model=model, + compile_kwargs={"inference": True}, + data={}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + tolerances=1e-6, + assert_shapes=False, + backends=backends, + ) + + +def test_ones_dynamic_with_dtype(): + dtypes = [mithril.float16, mithril.float32] + for dtype in dtypes: + backends: list[Backend[Any]] = [ + TorchBackend(dtype=dtype), + NumpyBackend(dtype=dtype), + JaxBackend(dtype=dtype), + ] + if platform.system() == "Darwin": + backends.append(MlxBackend(dtype=dtype)) + + model = Ones(shape=TBD, dtype=dtype) + + reference_outputs = {"output": list_full(1.0, 3, 2)} + compile_and_compare( + model=model, + compile_kwargs={"jit": False, "inference": True}, + data={"shape": (3, 2)}, + params={}, + output_gradients={}, + reference_outputs=reference_outputs, + reference_gradients=None, + tolerances=1e-6, + assert_shapes=False, + backends=backends, + ignore_transform={"shape"}, + ) + + def test_eye_complement_1(): model = EyeComplement(N=2, M=2)