Skip to content

Commit

Permalink
Add ge, gt, abs, and array_namespace_info
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Sep 22, 2024
1 parent 9031e27 commit be6a817
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 5 deletions.
5 changes: 4 additions & 1 deletion ndonnx/_logic_in_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@
where,
)
from .elementwise import (
abs,
isfinite,
isnan,
)
from .binary_functions import add, equal, maximum
from .infos import finfo, iinfo

from .namespace_info import __array_namespace_info__

__all__ = [
"__array_namespace_info__",
"Array",
"int8",
"int16",
Expand All @@ -56,6 +58,7 @@
"float64",
"bool",
"DType",
"abs",
"all",
"arange",
"asarray",
Expand Down
6 changes: 6 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,12 @@ def __add__(self, rhs: TyArrayBase) -> TyArrayBase:
def __radd__(self, lhs: TyArrayBase) -> TyArrayBase:
return _promote_and_apply_op(self, lhs, operator.add, op.add, False)

def __ge__(self, rhs: TyArrayBase, /) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.ge, op.greater_or_equal, False)

def __gt__(self, rhs: TyArrayBase, /) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.gt, op.greater, False)

def __truediv__(self, rhs: TyArrayBase) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.truediv, op.div, True)

Expand Down
12 changes: 12 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/py_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def __add__(self, rhs: TyArrayBase) -> TyArrayBase:
def __radd__(self, lhs: TyArrayBase) -> TyArrayBase:
return _promote_and_apply_op(self, lhs, operator.add, False)

def __le__(self, rhs: TyArrayBase, /) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.le, False)

def __lt__(self, rhs: TyArrayBase, /) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.lt, False)

def __ge__(self, rhs: TyArrayBase, /) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.ge, False)

def __gt__(self, rhs: TyArrayBase, /) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.gt, False)

def __mul__(self, rhs: TyArrayBase) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.mul, True)

Expand Down
9 changes: 9 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/typed_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def log2(self) -> Self:
raise ValueError(f"'log2' is not implemented for {self.dtype}")

# Dunder-functions
def __abs__(self) -> TyArrayBase:
raise ValueError(f"'__abs__' is not implemented for {self.dtype}")

def __add__(self, other: TyArrayBase) -> TyArrayBase:
return NotImplemented

Expand Down Expand Up @@ -187,6 +190,12 @@ def _eqcomp(self, other: TyArrayBase) -> TyArrayBase | NotImplementedType:
"""
...

def __ge__(self, other: TyArrayBase, /) -> TyArrayBase:
return NotImplemented

def __gt__(self, other: TyArrayBase, /) -> TyArrayBase:
return NotImplemented

def __invert__(self) -> TyArrayBase:
return NotImplemented

Expand Down
9 changes: 5 additions & 4 deletions ndonnx/_logic_in_data/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def __int__(self: Array, /) -> int:
##################################################################

def __abs__(self: Array, /) -> Array:
raise NotImplementedError
data = self._data.__abs__()
return Array._from_data(data)

def __add__(self: Array, rhs: int | float | Array, /) -> Array:
return _apply_op(self, rhs, std_ops.add)
Expand All @@ -182,13 +183,13 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: # t
return _apply_op(self, other, std_ops.eq)

def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
raise NotImplementedError
return _apply_op(self, other, std_ops.floordiv)

def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
raise NotImplementedError
return _apply_op(self, other, std_ops.ge)

def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
raise NotImplementedError
return _apply_op(self, other, std_ops.gt)

def __invert__(self: Array, /) -> Array:
raise NotImplementedError
Expand Down
4 changes: 4 additions & 0 deletions ndonnx/_logic_in_data/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from .array import Array


def abs(array: Array, /) -> Array:
return Array._from_data(array._data.__abs__())


def acos(array: Array, /) -> Array:
return Array._from_data(array._data.acos())

Expand Down
91 changes: 91 additions & 0 deletions ndonnx/_logic_in_data/namespace_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from typing import TypedDict

from . import dtypes

DefaultDataTypes = TypedDict(
"DefaultDataTypes",
{
"real floating": dtypes.DType,
"complex floating": dtypes.DType,
"integral": dtypes.DType,
"indexing": dtypes.DType,
},
)


class DataTypes(TypedDict, total=False):
bool: dtypes.DType
float32: dtypes.DType
float64: dtypes.DType
int8: dtypes.DType
int16: dtypes.DType
int32: dtypes.DType
int64: dtypes.DType
uint8: dtypes.DType
uint16: dtypes.DType
uint32: dtypes.DType
uint64: dtypes.DType


Capabilities = TypedDict(
"Capabilities",
{
"boolean indexing": bool,
"data-dependent shapes": bool,
"max rank": None | int,
},
)


class Info:
"""Namespace returned by `__array_namespace_info__`."""

def capabilities(self) -> Capabilities:
return {
"boolean indexing": True,
"data-dependent shapes": True,
"max rank": None,
}

def default_device(self) -> None:
raise ValueError("ndonnx does not define a default device")
...

def default_dtypes(self, *, device: None) -> DefaultDataTypes:
# TODO: We are not standard compliant until we support complex numbers
return { # type: ignore
"real floating": dtypes.default_float,
# "complex floating": dtypes.complex128,
"integral": dtypes.default_int,
"indexing": dtypes.default_int,
}

def devices(self) -> list[None]:
raise ValueError("ndonnx does not define devices")

def dtypes(self, *, device: None, kind: None | str | tuple[str, ...]) -> DataTypes:
return {
"bool": dtypes.bool_,
"float32": dtypes.float32,
"float64": dtypes.float64,
# "complex64": dtypes.DType,
# "complex128": dtypes.DType,
"int8": dtypes.int8,
"int16": dtypes.int16,
"int32": dtypes.int32,
"int64": dtypes.int64,
"uint8": dtypes.uint8,
"uint16": dtypes.uint16,
"uint32": dtypes.uint32,
"uint64": dtypes.uint64,
}


def __array_namespace_info__() -> Info:
return Info()


__all__ = ["__array_namespace_info__"]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ ignore = [
"N806", # https://docs.astral.sh/ruff/rules/non-lowercase-variable-in-function
"E501", # https://docs.astral.sh/ruff/faq/#is-the-ruff-linter-compatible-with-black
"UP038", # https://github.com/astral-sh/ruff/issues/7871
"N807", # Free functions may start/end with dunders __array_namespace_info__
"UP007",
]
select = [
Expand Down

0 comments on commit be6a817

Please sign in to comment.