Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ones primitive #227

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion mithril/cores/python/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@
"dtype",
"zeros_like",
"avg_pool2d",
"ones",
]


Expand Down Expand Up @@ -1137,11 +1138,30 @@ def zeros_like(input: jax.Array) -> jax.Array:
return jnp.zeros_like(input)


def ones(
shape: tuple[int, ...],
*,
dtype: jnp.dtype[Any] | None = None,
device: str,
default_dtype: str,
) -> jax.Array:
dtype = dtype_map[default_dtype] if dtype is None else dtype
with jax.default_device(get_device(device)):
return jnp.ones(shape, dtype=dtype)


def atleast_1d(input: jax.Array) -> jax.Array:
return jnp.atleast_1d(input)


array_creation_funcs = ["arange", "randn", "to_tensor", "eye", "ones_with_zero_diag"]
array_creation_funcs = [
"arange",
"randn",
"ones",
"to_tensor",
"eye",
"ones_with_zero_diag",
]
primitive_func_dict = common_primitive_func_dict = {
key: fn for key, fn in globals().items() if callable(fn)
} | common_primitive_func_dict
23 changes: 21 additions & 2 deletions mithril/cores/python/mlx/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@
tuple_converter,
union,
)
from . import utils

# from ...backends.with_autograd.mlx_backend import utils
from . import utils

AxisType = None | int | Sequence[int]

Expand Down Expand Up @@ -198,6 +198,7 @@
"dtype",
"zeros_like",
"avg_pool2d",
"ones",
]


Expand Down Expand Up @@ -1018,11 +1019,29 @@ def zeros_like(input: mx.array) -> mx.array:
return mx.zeros_like(input)


def ones(
shape: tuple[int, ...],
*,
dtype: mx.Dtype | None = None,
device: str,
default_dtype: str,
) -> mx.array:
dtype = utils.dtype_map[default_dtype] if dtype is None else dtype
return mx.ones(shape, dtype=dtype)


def atleast_1d(input: mx.array) -> mx.array:
return mx.atleast_1d(input)


array_creation_funcs = ["arange", "randn", "to_tensor", "eye", "ones_with_zero_diag"]
array_creation_funcs = [
"arange",
"randn",
"to_tensor",
"eye",
"ones_with_zero_diag",
"ones",
]
primitive_func_dict = common_primitive_func_dict = {
key: fn for key, fn in globals().items() if callable(fn)
} | common_primitive_func_dict
13 changes: 13 additions & 0 deletions mithril/cores/python/numpy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@
"dtype",
"zeros_like",
"avg_pool2d",
"ones",
]


Expand Down Expand Up @@ -1319,6 +1320,17 @@ def zeros_like(
return np.zeros_like(input)


def ones(
shape: tuple[int, ...],
*,
dtype: np.dtype[Any] | None = None,
default_dtype: str,
cache: CacheType | None = None,
) -> np.ndarray[Any, Any]:
dtype = dtype_map[default_dtype] if dtype is None else dtype
return np.ones(shape, dtype=dtype)


def atleast_1d(
input: np.ndarray[Any, Any], cache: CacheType | None = None
) -> np.ndarray[Any, Any]:
Expand Down Expand Up @@ -1698,6 +1710,7 @@ def cartesian_diff(
array_creation_funcs = [
"arange",
"randn",
"ones",
"to_tensor",
"make_array",
"eye",
Expand Down
21 changes: 20 additions & 1 deletion mithril/cores/python/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@
"dtype",
"zeros_like",
"avg_pool2d",
"ones",
]


Expand Down Expand Up @@ -1221,6 +1222,17 @@ def zeros_like(input: torch.Tensor) -> torch.Tensor:
return torch.zeros_like(input)


def ones(
shape: tuple[int, ...],
*,
dtype: torch.dtype | None = None,
device: str,
default_dtype: str,
) -> torch.Tensor:
dtype = dtype_map[default_dtype] if dtype is None else dtype
return torch.ones(shape, device=device, dtype=dtype)


def atleast_1d(input: torch.Tensor) -> torch.Tensor:
return torch.atleast_1d(input) # type: ignore

Expand All @@ -1229,7 +1241,14 @@ def primitive_embedding(input: torch.Tensor, weight: torch.Tensor) -> torch.Tens
return weight[input.long()]


array_creation_funcs = ["arange", "randn", "to_tensor", "eye", "ones_with_zero_diag"]
array_creation_funcs = [
"arange",
"randn",
"ones",
"to_tensor",
"eye",
"ones_with_zero_diag",
]
primitive_func_dict = common_primitive_func_dict | {
key: fn for key, fn in globals().items() if callable(fn)
}
30 changes: 30 additions & 0 deletions mithril/models/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@
"Trapezoid",
"Pad",
"Randn",
"Ones",
"PrimitiveModel",
"Buffer",
"ToTuple",
Expand Down Expand Up @@ -2675,6 +2676,35 @@ def __call__( # type: ignore[override]
return super().__call__(input=input, output=output)


class Ones(PrimitiveModel):
shape: Connection
dtype: Connection
output: Connection

def __init__(
self,
shape: tuple[int, ...] | list[int] | ToBeDetermined = TBD,
dtype: types.Dtype | None = None,
*,
name: str | None = None,
) -> None:
super().__init__(
formula_key="ones",
name=name,
output=BaseKey(shape=[("Var", ...)], type=Tensor),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to indicate shape if it is Variadic

shape=BaseKey(type=tuple[int, ...] | list[int], value=shape),
dtype=BaseKey(type=types.Dtype | None, value=dtype),
)

def __call__( # type: ignore[override]
self,
shape: ConnectionType = NOT_GIVEN,
dtype: ConnectionType = NOT_GIVEN,
output: ConnectionType = NOT_GIVEN,
) -> ExtendInfo:
return super().__call__(shape=shape, dtype=dtype, output=output)


class Buffer(OperatorModel):
input: Connection
output: Connection
Expand Down
93 changes: 93 additions & 0 deletions tests/scripts/test_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
NanToNum,
NormModifier,
NotEqual,
Ones,
PrimitiveUnion,
Prod,
Randn,
Expand Down Expand Up @@ -1532,6 +1533,98 @@ def test_zeros_like():
)


def test_ones_static_shape():
model = Ones(shape=(2, 3, 4))

reference_outputs = {"output": list_full(1.0, 2, 3, 4)}
compile_and_compare(
model=model,
compile_kwargs={"inference": True},
data={},
params={},
output_gradients={},
reference_outputs=reference_outputs,
reference_gradients=None,
tolerances=1e-6,
assert_shapes=False,
)


def test_ones_dynamic_shape():
model = Ones(shape=TBD)

reference_outputs = {"output": list_full(1.0, 3, 5, 2)}
compile_and_compare(
model=model,
compile_kwargs={"jit": False, "inference": True},
data={"shape": (3, 5, 2)},
params={},
output_gradients={},
reference_outputs=reference_outputs,
reference_gradients=None,
tolerances=1e-6,
assert_shapes=False,
ignore_transform={"shape"},
)


def test_ones_static_with_dtype():
dtypes = [mithril.float16, mithril.float32]
for dtype in dtypes:
backends: list[Backend[Any]] = [
TorchBackend(dtype=dtype),
NumpyBackend(dtype=dtype),
JaxBackend(dtype=dtype),
]
if platform.system() == "Darwin":
backends.append(MlxBackend(dtype=dtype))

model = Ones(shape=(2, 4), dtype=dtype)

reference_outputs = {"output": list_full(1.0, 2, 4)}
compile_and_compare(
model=model,
compile_kwargs={"inference": True},
data={},
params={},
output_gradients={},
reference_outputs=reference_outputs,
reference_gradients=None,
tolerances=1e-6,
assert_shapes=False,
backends=backends,
)


def test_ones_dynamic_with_dtype():
dtypes = [mithril.float16, mithril.float32]
for dtype in dtypes:
backends: list[Backend[Any]] = [
TorchBackend(dtype=dtype),
NumpyBackend(dtype=dtype),
JaxBackend(dtype=dtype),
]
if platform.system() == "Darwin":
backends.append(MlxBackend(dtype=dtype))

model = Ones(shape=TBD, dtype=dtype)

reference_outputs = {"output": list_full(1.0, 3, 2)}
compile_and_compare(
model=model,
compile_kwargs={"jit": False, "inference": True},
data={"shape": (3, 2)},
params={},
output_gradients={},
reference_outputs=reference_outputs,
reference_gradients=None,
tolerances=1e-6,
assert_shapes=False,
backends=backends,
ignore_transform={"shape"},
)


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can add randomized tests for this model

def test_eye_complement_1():
model = EyeComplement(N=2, M=2)

Expand Down