From adfc29998cbf53b56743808e091209ef1f78fc05 Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Mon, 27 Jan 2025 22:06:10 +0100 Subject: [PATCH] Add Device class --- ndonnx/_refactor/_array.py | 3 +- ndonnx/_refactor/_funcs.py | 44 ++++++++++++++++++++--------- ndonnx/_refactor/_namespace_info.py | 27 ++++++++++++------ ndonnx/_refactor/extensions.py | 7 +++-- 4 files changed, 57 insertions(+), 24 deletions(-) diff --git a/ndonnx/_refactor/_array.py b/ndonnx/_refactor/_array.py index abe0dd3..17e0a6e 100644 --- a/ndonnx/_refactor/_array.py +++ b/ndonnx/_refactor/_array.py @@ -18,6 +18,7 @@ from spox import Var from ._dtypes import DType +from ._namespace_info import Device from ._typed_array import TyArrayBase, astyarray, onnx if TYPE_CHECKING: @@ -316,7 +317,7 @@ def asarray( /, *, dtype: DType | None = None, - device=None, + device: None | Device = None, copy: bool | None = None, ) -> Array: if isinstance(obj, Var | str): diff --git a/ndonnx/_refactor/_funcs.py b/ndonnx/_refactor/_funcs.py index 01a229c..d393fa6 100644 --- a/ndonnx/_refactor/_funcs.py +++ b/ndonnx/_refactor/_funcs.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations @@ -14,6 +14,7 @@ from ._array import Array, asarray from ._dtypes import DType +from ._namespace_info import Device from ._typed_array import funcs as tyfuncs from ._typed_array import onnx @@ -43,7 +44,7 @@ def arange( step: int | float = 1, *, dtype: DType | None = None, - device=None, + device: None | Device = None, ) -> Array: if dtype is None: if builtins.all( @@ -91,7 +92,9 @@ def nonzero(x: Array, /) -> tuple[Array, ...]: return tuple(Array._from_tyarray(el) for el in x._tyarray.nonzero()) -def astype(x: Array, dtype: DType, /, *, copy: bool = True, device=None) -> Array: +def astype( + x: Array, dtype: DType, /, *, copy: bool = True, device: None | Device = None +) -> Array: if not copy and x.dtype == dtype: return x return x.astype(dtype) @@ -263,12 +266,17 @@ def var( def empty( - shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None + shape: int | tuple[int, ...], + *, + dtype: DType | None = None, + device: None | Device = None, ) -> Array: return zeros(shape=shape, dtype=dtype) -def empty_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array: +def empty_like( + x: Array, /, *, dtype: DType | None = None, device: None | Device = None +) -> Array: return zeros_like(x, dtype=dtype) @@ -319,7 +327,7 @@ def eye( *, k: int = 0, dtype: DType | None = None, - device=None, + device: None | Device = None, ) -> Array: nparr = np.eye(n_rows, n_cols, k=k) return asarray(nparr, dtype=dtype) @@ -342,7 +350,7 @@ def full( fill_value: bool | int | float | str, *, dtype: DType | None = None, - device=None, + device: None | Device = None, ) -> Array: if dtype is None: if isinstance(fill_value, bool): @@ -379,7 +387,7 @@ def full_like( fill_value: bool | int | float | str, *, dtype: DType | None = None, - device=None, + device: None | Device = None, ) -> Array: shape = x.dynamic_shape fill = asarray(fill_value, dtype=dtype or x.dtype) @@ -417,7 +425,7 @@ def linspace( num: int, *, dtype: DType | None = None, - device=None, + device: None | Device = None, endpoint: bool = True, ) -> Array: dtype = dtype or ndx._default_float @@ -435,14 +443,19 @@ def matrix_transpose(x: Array, /) -> Array: def ones( - shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None + shape: int | tuple[int, ...], + *, + dtype: DType | None = None, + device: None | Device = None, ) -> Array: dtype = dtype or ndx._default_float shape = (shape,) if isinstance(shape, int) else shape return Array._from_tyarray(dtype._ones(shape)) -def ones_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array: +def ones_like( + x: Array, /, *, dtype: DType | None = None, device: None | Device = None +) -> Array: dtype = dtype or x.dtype return full_like(x, 1, dtype=dtype) @@ -634,13 +647,18 @@ def where(cond: Array, a: Array, b: Array) -> Array: def zeros( - shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None + shape: int | tuple[int, ...], + *, + dtype: DType | None = None, + device: None | Device = None, ) -> Array: dtype = dtype or ndx._default_float shape = (shape,) if isinstance(shape, int) else shape return Array._from_tyarray(dtype._zeros(shape)) -def zeros_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array: +def zeros_like( + x: Array, /, *, dtype: DType | None = None, device: None | Device = None +) -> Array: dtype = dtype or x.dtype return full_like(x, 0, dtype=dtype) diff --git a/ndonnx/_refactor/_namespace_info.py b/ndonnx/_refactor/_namespace_info.py index 7f3341d..8d0e0ce 100644 --- a/ndonnx/_refactor/_namespace_info.py +++ b/ndonnx/_refactor/_namespace_info.py @@ -1,5 +1,6 @@ -# Copyright (c) QuantCo 2023-2024 +# Copyright (c) QuantCo 2023-2025 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations from typing import TypedDict @@ -51,10 +52,10 @@ def capabilities(self) -> Capabilities: "max rank": None, } - def default_device(self) -> None: - return None + def default_device(self) -> Device: + return device - def default_dtypes(self, *, device=None) -> DefaultDataTypes: + def default_dtypes(self, *, device: None | Device = None) -> DefaultDataTypes: # TODO: We are not standard compliant until we support complex numbers return { # type: ignore "real floating": onnx.float64, @@ -63,10 +64,12 @@ def default_dtypes(self, *, device=None) -> DefaultDataTypes: "indexing": onnx.int64, } - def devices(self) -> list[None]: - raise ValueError("ndonnx does not define devices") + def devices(self) -> list[Device]: + return [device] - def dtypes(self, *, device=None, kind: None | str | tuple[str, ...]) -> DataTypes: + def dtypes( + self, *, device: None | Device = None, kind: None | str | tuple[str, ...] + ) -> DataTypes: return { "bool": onnx.bool_, "float32": onnx.float32, @@ -84,8 +87,16 @@ def dtypes(self, *, device=None, kind: None | str | tuple[str, ...]) -> DataType } +class Device: + def __eq__(self, other) -> bool: + return self is device + + +device = Device() + + def __array_namespace_info__() -> Info: return Info() -__all__ = ["__array_namespace_info__"] +__all__ = ["__array_namespace_info__", "device"] diff --git a/ndonnx/_refactor/extensions.py b/ndonnx/_refactor/extensions.py index 5115ebe..af5c189 100644 --- a/ndonnx/_refactor/extensions.py +++ b/ndonnx/_refactor/extensions.py @@ -10,8 +10,9 @@ import numpy as np import ndonnx._refactor as ndx -from ndonnx._refactor import _typed_array as tydx -from ndonnx._refactor._typed_array.masked_onnx import TyMaArray + +from . import _typed_array as tydx +from ._typed_array.masked_onnx import TyMaArray Scalar = TypeVar("Scalar", int, float, str) @@ -122,6 +123,8 @@ def fill_null(x: ndx.Array, /, value: ndx.Array | Scalar) -> ndx.Array: A new Array with the null values filled with the given value. """ value_ = tydx.astyarray(value) + if isinstance(value_.dtype, ndx.Nullable): + raise ValueError("'fill_null' expects a none-nullable fill value data type.") xty = x._tyarray if isinstance(xty, tydx.masked_onnx.TyMaArray): result_type = ndx.result_type(xty.data.dtype, value_.dtype)