Skip to content

Commit

Permalink
Add Device class
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Jan 27, 2025
1 parent 63cb15b commit 553b38f
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 23 deletions.
3 changes: 2 additions & 1 deletion ndonnx/_refactor/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from spox import Var

from ._dtypes import DType
from ._namespace_info import Device
from ._typed_array import TyArrayBase, astyarray, onnx

if TYPE_CHECKING:
Expand Down Expand Up @@ -316,7 +317,7 @@ def asarray(
/,
*,
dtype: DType | None = None,
device=None,
device: None | Device = None,
copy: bool | None = None,
) -> Array:
if isinstance(obj, Var | str):
Expand Down
44 changes: 31 additions & 13 deletions ndonnx/_refactor/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand All @@ -14,6 +14,7 @@

from ._array import Array, asarray
from ._dtypes import DType
from ._namespace_info import Device
from ._typed_array import funcs as tyfuncs
from ._typed_array import onnx

Expand Down Expand Up @@ -43,7 +44,7 @@ def arange(
step: int | float = 1,
*,
dtype: DType | None = None,
device=None,
device: None | Device = None,
) -> Array:
if dtype is None:
if builtins.all(
Expand Down Expand Up @@ -91,7 +92,9 @@ def nonzero(x: Array, /) -> tuple[Array, ...]:
return tuple(Array._from_tyarray(el) for el in x._tyarray.nonzero())


def astype(x: Array, dtype: DType, /, *, copy: bool = True, device=None) -> Array:
def astype(
x: Array, dtype: DType, /, *, copy: bool = True, device: None | Device = None
) -> Array:
if not copy and x.dtype == dtype:
return x
return x.astype(dtype)
Expand Down Expand Up @@ -263,12 +266,17 @@ def var(


def empty(
shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None
shape: int | tuple[int, ...],
*,
dtype: DType | None = None,
device: None | Device = None,
) -> Array:
return zeros(shape=shape, dtype=dtype)


def empty_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array:
def empty_like(
x: Array, /, *, dtype: DType | None = None, device: None | Device = None
) -> Array:
return zeros_like(x, dtype=dtype)


Expand Down Expand Up @@ -319,7 +327,7 @@ def eye(
*,
k: int = 0,
dtype: DType | None = None,
device=None,
device: None | Device = None,
) -> Array:
nparr = np.eye(n_rows, n_cols, k=k)
return asarray(nparr, dtype=dtype)
Expand All @@ -342,7 +350,7 @@ def full(
fill_value: bool | int | float | str,
*,
dtype: DType | None = None,
device=None,
device: None | Device = None,
) -> Array:
if dtype is None:
if isinstance(fill_value, bool):
Expand Down Expand Up @@ -379,7 +387,7 @@ def full_like(
fill_value: bool | int | float | str,
*,
dtype: DType | None = None,
device=None,
device: None | Device = None,
) -> Array:
shape = x.dynamic_shape
fill = asarray(fill_value, dtype=dtype or x.dtype)
Expand Down Expand Up @@ -417,7 +425,7 @@ def linspace(
num: int,
*,
dtype: DType | None = None,
device=None,
device: None | Device = None,
endpoint: bool = True,
) -> Array:
dtype = dtype or ndx._default_float
Expand All @@ -435,14 +443,19 @@ def matrix_transpose(x: Array, /) -> Array:


def ones(
shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None
shape: int | tuple[int, ...],
*,
dtype: DType | None = None,
device: None | Device = None,
) -> Array:
dtype = dtype or ndx._default_float
shape = (shape,) if isinstance(shape, int) else shape
return Array._from_tyarray(dtype._ones(shape))


def ones_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array:
def ones_like(
x: Array, /, *, dtype: DType | None = None, device: None | Device = None
) -> Array:
dtype = dtype or x.dtype
return full_like(x, 1, dtype=dtype)

Expand Down Expand Up @@ -634,13 +647,18 @@ def where(cond: Array, a: Array, b: Array) -> Array:


def zeros(
shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None
shape: int | tuple[int, ...],
*,
dtype: DType | None = None,
device: None | Device = None,
) -> Array:
dtype = dtype or ndx._default_float
shape = (shape,) if isinstance(shape, int) else shape
return Array._from_tyarray(dtype._zeros(shape))


def zeros_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array:
def zeros_like(
x: Array, /, *, dtype: DType | None = None, device: None | Device = None
) -> Array:
dtype = dtype or x.dtype
return full_like(x, 0, dtype=dtype)
23 changes: 16 additions & 7 deletions ndonnx/_refactor/_namespace_info.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TypedDict

Expand Down Expand Up @@ -51,10 +52,10 @@ def capabilities(self) -> Capabilities:
"max rank": None,
}

def default_device(self) -> None:
return None
def default_device(self) -> Device:
return device

def default_dtypes(self, *, device=None) -> DefaultDataTypes:
def default_dtypes(self, *, device: None | Device = None) -> DefaultDataTypes:
# TODO: We are not standard compliant until we support complex numbers
return { # type: ignore
"real floating": onnx.float64,
Expand All @@ -63,10 +64,10 @@ def default_dtypes(self, *, device=None) -> DefaultDataTypes:
"indexing": onnx.int64,
}

def devices(self) -> list[None]:
raise ValueError("ndonnx does not define devices")
def devices(self) -> list[Device]:
return [device]

def dtypes(self, *, device=None, kind: None | str | tuple[str, ...]) -> DataTypes:
def dtypes(self, *, device: None | Device = None) -> DataTypes:
return {
"bool": onnx.bool_,
"float32": onnx.float32,
Expand All @@ -84,8 +85,16 @@ def dtypes(self, *, device=None, kind: None | str | tuple[str, ...]) -> DataType
}


class Device:
def __eq__(self, other) -> bool:
return self is device


device = Device()


def __array_namespace_info__() -> Info:
return Info()


__all__ = ["__array_namespace_info__"]
__all__ = ["__array_namespace_info__", "device"]
7 changes: 5 additions & 2 deletions ndonnx/_refactor/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import numpy as np

import ndonnx._refactor as ndx
from ndonnx._refactor import _typed_array as tydx
from ndonnx._refactor._typed_array.masked_onnx import TyMaArray

from . import _typed_array as tydx
from ._typed_array.masked_onnx import TyMaArray

Scalar = TypeVar("Scalar", int, float, str)

Expand Down Expand Up @@ -122,6 +123,8 @@ def fill_null(x: ndx.Array, /, value: ndx.Array | Scalar) -> ndx.Array:
A new Array with the null values filled with the given value.
"""
value_ = tydx.astyarray(value)
if isinstance(value_.dtype, ndx.Nullable):
raise ValueError("'fill_null' expects a none-nullable fill value data type.")
xty = x._tyarray
if isinstance(xty, tydx.masked_onnx.TyMaArray):
result_type = ndx.result_type(xty.data.dtype, value_.dtype)
Expand Down

0 comments on commit 553b38f

Please sign in to comment.