diff --git a/ndonnx/_logic_in_data/__init__.py b/ndonnx/_logic_in_data/__init__.py index 21be632..2e29574 100644 --- a/ndonnx/_logic_in_data/__init__.py +++ b/ndonnx/_logic_in_data/__init__.py @@ -34,14 +34,16 @@ where, ) from .elementwise import ( + abs, isfinite, isnan, ) from .binary_functions import add, equal, maximum from .infos import finfo, iinfo - +from .namespace_info import __array_namespace_info__ __all__ = [ + "__array_namespace_info__", "Array", "int8", "int16", @@ -56,6 +58,7 @@ "float64", "bool", "DType", + "abs", "all", "arange", "asarray", diff --git a/ndonnx/_logic_in_data/_typed_array/core.py b/ndonnx/_logic_in_data/_typed_array/core.py index 0882326..7ea20b1 100644 --- a/ndonnx/_logic_in_data/_typed_array/core.py +++ b/ndonnx/_logic_in_data/_typed_array/core.py @@ -181,6 +181,12 @@ def __add__(self, rhs: TyArrayBase) -> TyArrayBase: def __radd__(self, lhs: TyArrayBase) -> TyArrayBase: return _promote_and_apply_op(self, lhs, operator.add, op.add, False) + def __ge__(self, rhs: TyArrayBase, /) -> TyArrayBase: + return _promote_and_apply_op(self, rhs, operator.ge, op.greater_or_equal, False) + + def __gt__(self, rhs: TyArrayBase, /) -> TyArrayBase: + return _promote_and_apply_op(self, rhs, operator.gt, op.greater, False) + def __truediv__(self, rhs: TyArrayBase) -> TyArrayBase: return _promote_and_apply_op(self, rhs, operator.truediv, op.div, True) diff --git a/ndonnx/_logic_in_data/_typed_array/py_scalars.py b/ndonnx/_logic_in_data/_typed_array/py_scalars.py index 8c156d2..c24bb90 100644 --- a/ndonnx/_logic_in_data/_typed_array/py_scalars.py +++ b/ndonnx/_logic_in_data/_typed_array/py_scalars.py @@ -67,6 +67,18 @@ def __add__(self, rhs: TyArrayBase) -> TyArrayBase: def __radd__(self, lhs: TyArrayBase) -> TyArrayBase: return _promote_and_apply_op(self, lhs, operator.add, False) + def __le__(self, rhs: TyArrayBase, /) -> TyArrayBase: + return _promote_and_apply_op(self, rhs, operator.le, False) + + def __lt__(self, rhs: TyArrayBase, /) -> TyArrayBase: + return _promote_and_apply_op(self, rhs, operator.lt, False) + + def __ge__(self, rhs: TyArrayBase, /) -> TyArrayBase: + return _promote_and_apply_op(self, rhs, operator.ge, False) + + def __gt__(self, rhs: TyArrayBase, /) -> TyArrayBase: + return _promote_and_apply_op(self, rhs, operator.gt, False) + def __mul__(self, rhs: TyArrayBase) -> TyArrayBase: return _promote_and_apply_op(self, rhs, operator.mul, True) diff --git a/ndonnx/_logic_in_data/_typed_array/typed_array.py b/ndonnx/_logic_in_data/_typed_array/typed_array.py index 2c08384..f9900d6 100644 --- a/ndonnx/_logic_in_data/_typed_array/typed_array.py +++ b/ndonnx/_logic_in_data/_typed_array/typed_array.py @@ -160,6 +160,9 @@ def log2(self) -> Self: raise ValueError(f"'log2' is not implemented for {self.dtype}") # Dunder-functions + def __abs__(self) -> TyArrayBase: + raise ValueError(f"'__abs__' is not implemented for {self.dtype}") + def __add__(self, other: TyArrayBase) -> TyArrayBase: return NotImplemented @@ -187,6 +190,12 @@ def _eqcomp(self, other: TyArrayBase) -> TyArrayBase | NotImplementedType: """ ... + def __ge__(self, other: TyArrayBase, /) -> TyArrayBase: + return NotImplemented + + def __gt__(self, other: TyArrayBase, /) -> TyArrayBase: + return NotImplemented + def __invert__(self) -> TyArrayBase: return NotImplemented diff --git a/ndonnx/_logic_in_data/array.py b/ndonnx/_logic_in_data/array.py index a7df68a..4182d97 100644 --- a/ndonnx/_logic_in_data/array.py +++ b/ndonnx/_logic_in_data/array.py @@ -159,7 +159,8 @@ def __int__(self: Array, /) -> int: ################################################################## def __abs__(self: Array, /) -> Array: - raise NotImplementedError + data = self._data.__abs__() + return Array._from_data(data) def __add__(self: Array, rhs: int | float | Array, /) -> Array: return _apply_op(self, rhs, std_ops.add) @@ -182,13 +183,13 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: # t return _apply_op(self, other, std_ops.eq) def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array: - raise NotImplementedError + return _apply_op(self, other, std_ops.floordiv) def __ge__(self: Array, other: Union[int, float, Array], /) -> Array: - raise NotImplementedError + return _apply_op(self, other, std_ops.ge) def __gt__(self: Array, other: Union[int, float, Array], /) -> Array: - raise NotImplementedError + return _apply_op(self, other, std_ops.gt) def __invert__(self: Array, /) -> Array: raise NotImplementedError diff --git a/ndonnx/_logic_in_data/elementwise.py b/ndonnx/_logic_in_data/elementwise.py index 8d8701a..88a83d0 100644 --- a/ndonnx/_logic_in_data/elementwise.py +++ b/ndonnx/_logic_in_data/elementwise.py @@ -8,6 +8,10 @@ from .array import Array +def abs(array: Array, /) -> Array: + return Array._from_data(array._data.__abs__()) + + def acos(array: Array, /) -> Array: return Array._from_data(array._data.acos()) diff --git a/ndonnx/_logic_in_data/namespace_info.py b/ndonnx/_logic_in_data/namespace_info.py new file mode 100644 index 0000000..4eb3dfa --- /dev/null +++ b/ndonnx/_logic_in_data/namespace_info.py @@ -0,0 +1,91 @@ +# Copyright (c) QuantCo 2023-2024 +# SPDX-License-Identifier: BSD-3-Clause + +from typing import TypedDict + +from . import dtypes + +DefaultDataTypes = TypedDict( + "DefaultDataTypes", + { + "real floating": dtypes.DType, + "complex floating": dtypes.DType, + "integral": dtypes.DType, + "indexing": dtypes.DType, + }, +) + + +class DataTypes(TypedDict, total=False): + bool: dtypes.DType + float32: dtypes.DType + float64: dtypes.DType + int8: dtypes.DType + int16: dtypes.DType + int32: dtypes.DType + int64: dtypes.DType + uint8: dtypes.DType + uint16: dtypes.DType + uint32: dtypes.DType + uint64: dtypes.DType + + +Capabilities = TypedDict( + "Capabilities", + { + "boolean indexing": bool, + "data-dependent shapes": bool, + "max rank": None | int, + }, +) + + +class Info: + """Namespace returned by `__array_namespace_info__`.""" + + def capabilities(self) -> Capabilities: + return { + "boolean indexing": True, + "data-dependent shapes": True, + "max rank": None, + } + + def default_device(self) -> None: + raise ValueError("ndonnx does not define a default device") + ... + + def default_dtypes(self, *, device: None) -> DefaultDataTypes: + # TODO: We are not standard compliant until we support complex numbers + return { # type: ignore + "real floating": dtypes.default_float, + # "complex floating": dtypes.complex128, + "integral": dtypes.default_int, + "indexing": dtypes.default_int, + } + + def devices(self) -> list[None]: + raise ValueError("ndonnx does not define devices") + + def dtypes(self, *, device: None, kind: None | str | tuple[str, ...]) -> DataTypes: + return { + "bool": dtypes.bool_, + "float32": dtypes.float32, + "float64": dtypes.float64, + # "complex64": dtypes.DType, + # "complex128": dtypes.DType, + "int8": dtypes.int8, + "int16": dtypes.int16, + "int32": dtypes.int32, + "int64": dtypes.int64, + "uint8": dtypes.uint8, + "uint16": dtypes.uint16, + "uint32": dtypes.uint32, + "uint64": dtypes.uint64, + } + + +def __array_namespace_info__() -> Info: + return Info() + + +__all__ = ["__array_namespace_info__"] diff --git a/pyproject.toml b/pyproject.toml index 9cbda55..9e6ad31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ ignore = [ "N806", # https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function "E501", # https://docs.astral.sh/ruff/faq/#is-the-ruff-linter-compatible-with-black "UP038", # https://github.com/astral-sh/ruff/issues/7871 + "N807", # Free functions may start/end with dunders __array_namespace_info__ "UP007", ] select = [