Skip to content

Commit

Permalink
Implement __eq__ and more date time related things
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Sep 5, 2024
1 parent 1d0ef5f commit cdf3d12
Showing 10 changed files with 143 additions and 23 deletions.
4 changes: 4 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/core.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

from __future__ import annotations

import operator
from collections.abc import Callable, Sequence
from types import NotImplementedType
from typing import TYPE_CHECKING, TypeGuard, TypeVar
@@ -91,6 +92,9 @@ def as_core_dtype(self, dtype: CoreDTypes) -> TyArray:
def _astype(self, dtype: DType) -> TyArrayBase:
return NotImplemented

def _eqcomp(self, other) -> TyArrayBase:
    """Equality comparison of `self` with `other`.

    Forwards both operands to `_promote_and_apply_op` together with the
    two equality callables `operator.eq` and `op.equal`.
    """
    eq_callables = (operator.eq, op.equal)
    return _promote_and_apply_op(self, other, *eq_callables)

def _where(
self, cond: TyArrayBool, y: TyArrayBase
) -> TyArrayBase | NotImplementedType:
27 changes: 23 additions & 4 deletions ndonnx/_logic_in_data/_typed_array/date_time.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import numpy as np

from ..dtypes import CoreIntegerDTypes, DType, bool_, int64
from .core import TyArrayBool, TyArrayInt64
from .core import TyArrayBool, TyArrayInt64, TyArrayInteger
from .funcs import astypedarray, typed_where
from .py_scalars import _ArrayPyInt
from .typed_array import TyArrayBase
@@ -25,6 +25,8 @@

Unit = Literal["ns", "s"]

_NAT_SENTINEL = _ArrayPyInt(np.iinfo(np.int64).min)


class DateTime(DType):
def __init__(self, unit: Unit):
@@ -38,6 +40,14 @@ def _result_type(self, other: DType) -> DType | NotImplementedType:
def _tyarr_class(self) -> type[TyArrayDateTime]:
return TyArrayDateTime

def _tyarray_from_tyarray(self, arr: TyArrayBase) -> TyArrayDateTime:
    """Create a `TyArrayDateTime` carrying this dtype's unit from `arr`.

    Only `TyArrayInteger` inputs are handled: the values are cast to
    `int64` and every element equal to `_NAT_SENTINEL` is flagged in the
    `is_nat` mask. Any other input raises `NotImplementedError`.
    """
    if not isinstance(arr, TyArrayInteger):
        raise NotImplementedError

    as_int64 = safe_cast(TyArrayInt64, arr.astype(int64))
    nat_mask = safe_cast(TyArrayBool, as_int64 == _NAT_SENTINEL)
    return TyArrayDateTime(is_nat=nat_mask, data=as_int64, unit=self.unit)

def __repr__(self) -> str:
return f"{self.__class__.__name__}[{self.unit}]"

@@ -54,6 +64,9 @@ def _result_type(self, other: DType) -> DType | NotImplementedType:
def _tyarr_class(self) -> type[TyArrayTimeDelta]:
return TyArrayTimeDelta

def _tyarray_from_tyarray(self, arr: TyArrayBase) -> TyArrayTimeDelta:
    """Create a `TyArrayTimeDelta` from `arr` (not implemented yet).

    The return annotation names the array class rather than `Self`:
    this method lives on the dtype, so `Self` would denote the dtype
    itself instead of the array it is supposed to produce.

    Always raises `NotImplementedError`.
    """
    raise NotImplementedError

def __repr__(self) -> str:
return f"{self.__class__.__name__}[{self.unit}]"

@@ -81,6 +94,11 @@ def __getitem__(self, index: Index) -> Self:
def from_typed_array(cls, tyarr: TyArrayBase):
    """Build an instance of `cls` from `tyarr`.

    An array that already is an instance of `cls` is re-wrapped with its
    own `is_nat`, `data` and unit. Every other input yields
    `NotImplemented`.
    """
    if isinstance(tyarr, cls):
        return cls(is_nat=tyarr.is_nat, data=tyarr.data, unit=tyarr.dtype.unit)

    # Integer arrays are deliberately not converted here. The target unit is
    # not known in this classmethod (the previous attempt read
    # `tyarr.dtype.unit` from the *integer* input). Integer -> datetime
    # conversion needs the dtype instance and is done by
    # `DateTime._tyarray_from_tyarray`.
    return NotImplemented

@@ -94,6 +112,9 @@ def reshape(self, shape: tuple[int, ...]) -> Self:

return type(self)(is_nat=is_nat, data=data, unit=self.dtype.unit)

def _eqcomp(self, other) -> TyArrayBase:
    """Equality comparison hook; not implemented for this array type yet.

    Always raises `NotImplementedError`.
    """
    raise NotImplementedError()


class TyArrayTimeDelta(TimeBaseArray[TimeDelta]):
def __init__(self, is_nat: TyArrayBool, data: TyArrayInt64, unit: Unit):
@@ -113,9 +134,7 @@ def as_argument(cls, shape: OnnxShape, dtype: DType) -> Self:

def _astype(self, dtype: DType) -> TyArrayBase | NotImplementedType:
if isinstance(dtype, CoreIntegerDTypes):
data = typed_where(
self.is_nat, astypedarray(np.iinfo(np.int64).min), self.data
)
data = typed_where(self.is_nat, _NAT_SENTINEL, self.data)
return data.astype(dtype)
if isinstance(dtype, TimeDelta):
powers = {
4 changes: 2 additions & 2 deletions ndonnx/_logic_in_data/_typed_array/funcs.py
Original file line number Diff line number Diff line change
@@ -18,9 +18,9 @@ def typed_where(cond: TyArrayBase, x: TyArrayBase, y: TyArrayBase) -> TyArrayBas
raise TypeError("'cond' must be a boolean data type.")

ret = x._where(cond, y)
if ret == NotImplemented:
if ret is NotImplemented:
ret = y._rwhere(cond, x)
if ret == NotImplemented:
if ret is NotImplemented:
raise TypeError(
f"Unsuppoerted operand data types for 'where': `{x.dtype}` and `{y.dtype}`"
)
4 changes: 4 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/masked.py
Original file line number Diff line number Diff line change
@@ -155,6 +155,10 @@ def __radd__(self, lhs: TyArrayBase) -> TyMaArray:

return NotImplemented

def _eqcomp(self, other: TyArrayBase) -> TyArrayBase | NotImplementedType:
    """Equality comparison hook; not implemented for masked arrays yet.

    Always raises `NotImplementedError`. The unreachable `...` placeholder
    that used to follow the `raise` has been dropped.
    """
    raise NotImplementedError()


class TyMaArrayNumber(TyMaArray[NCORE_NUMERIC_DTYPES]): ...

7 changes: 5 additions & 2 deletions ndonnx/_logic_in_data/_typed_array/py_scalars.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
from .. import dtypes
from ..dtypes import DType
from .core import TyArray, TyArrayNumber
from .masked import TyMaArrayNumber
from .masked import TyMaArray, TyMaArrayNumber
from .typed_array import DTYPE, TyArrayBase
from .utils import promote, safe_cast

@@ -89,6 +89,9 @@ def _astype(self, dtype: DType) -> TyArrayBase:
return asncoredata(unmasked_typed_arr, None)
raise NotImplementedError

def _eqcomp(self, other) -> TyArrayBase:
    """Equality comparison of this Python scalar with `other`.

    Hands both operands and `operator.eq` to `_promote_and_apply_op`.
    """
    compare = operator.eq
    return _promote_and_apply_op(self, other, compare)

def where(self, cond: TyArrayBool, y: TyArrayBase) -> TyArrayBase:
raise NotImplementedError

@@ -110,7 +113,7 @@ def _promote_and_apply_op(
if isinstance(rhs, _ArrayPyScalar) != isinstance(lhs, _ArrayPyScalar):
lhs, rhs = promote(lhs, rhs)
# I am not sure how to annotate `safe_cast` such that it can handle union types.
return safe_cast(TyArrayNumber | TyMaArrayNumber, lhs + rhs) # type: ignore
return safe_cast(TyArray | TyMaArray, arr_op(lhs, rhs)) # type: ignore

# We only know about the other (nullable) built-in types &
# these scalars should never interact with themselves.
31 changes: 28 additions & 3 deletions ndonnx/_logic_in_data/_typed_array/typed_array.py
Original file line number Diff line number Diff line change
@@ -83,10 +83,10 @@ def astype(self, dtype: DType) -> TyArrayBase: ...
def astype(self, dtype: DType) -> TyArrayBase:
"""Convert `self` to the `_TypedArray` associated with `dtype`."""
res = self._astype(dtype)
if res == NotImplemented:
if res is NotImplemented:
# `type(self._data)` does not know about the target `dtype`
res = dtype._tyarr_class.from_typed_array(self)
if res != NotImplemented:
res = dtype._tyarray_from_tyarray(self)
if res is not NotImplemented:
return res
raise ValueError(f"casting between `{self.dtype}` and `{dtype}` is undefined")

@@ -116,12 +116,37 @@ def __add__(self, other: TyArrayBase) -> TyArrayBase:
def __and__(self, rhs: TyArrayBase) -> TyArrayBase:
return NotImplemented

# mypy believes that __eq__ should return a `bool` but the docs say we can return whatever:
# https://docs.python.org/3/reference/datamodel.html#object.__eq__
def __eq__(self, other) -> TyArrayBase:  # type: ignore
    """Compare `self` and `other` for equality via the `_eqcomp` hooks.

    `self._eqcomp(other)` is tried first. If it returns `NotImplemented`,
    the reflected `other._eqcomp(self)` is tried, provided `other` offers
    such a hook at all.

    Raises `ValueError` if neither operand implements the comparison.
    """
    res = self._eqcomp(other)
    if res is NotImplemented:
        # `other` may be an arbitrary object without an `_eqcomp` hook.
        # Look the hook up defensively so the failure is the `ValueError`
        # below rather than an unrelated `AttributeError`.
        reflected = getattr(other, "_eqcomp", None)
        if reflected is not None:
            res = reflected(self)
    if res is NotImplemented:
        raise ValueError(
            f"comparison between `{type(self).__name__}` and `{type(other).__name__}` is not implemented."
        )
    return res

@abstractmethod
def _eqcomp(self, other: TyArrayBase) -> TyArrayBase | NotImplementedType:
    """Hook that implements the equality comparison behind `__eq__`.

    `__eq__` is routed through this separate hook because it has special
    semantics compared to the other dunder methods, see
    https://docs.python.org/3/reference/datamodel.html#object.__eq__
    """
    ...

def __invert__(self) -> TyArrayBase:
return NotImplemented

def __mul__(self, rhs: TyArrayBase) -> TyArrayBase:
return NotImplemented

def __ne__(self, other: TyArrayBase) -> TyArrayBase:  # type: ignore
    """Not-equal comparison; not implemented by this base class.

    Returns `NotImplemented`. A leftover `breakpoint()` debugging call
    that dropped into the debugger on every `!=` has been removed.
    """
    return NotImplemented

def __or__(self, rhs: TyArrayBase) -> TyArrayBase:
return NotImplemented

6 changes: 4 additions & 2 deletions ndonnx/_logic_in_data/_typed_array/utils.py
Original file line number Diff line number Diff line change
@@ -36,7 +36,9 @@ def promote(lhs: TyArrayBase, *others: TyArrayBase) -> tuple[TyArrayBase[DType],
return tuple([lhs.astype(acc)] + [other.astype(acc) for other in others])


def safe_cast(ty: type[T], a: TyArrayBase) -> T:
def safe_cast(ty: type[T], a: TyArrayBase | bool) -> T:
# The union with bool is due to mypy thinking that __eq__ always
# returns bool.
if isinstance(a, ty):
return a
raise TypeError(f"Expected 'TyArrayInt64' found `{type(a)}`")
raise TypeError(f"Expected `{ty}` found `{type(a)}`")
2 changes: 1 addition & 1 deletion ndonnx/_logic_in_data/array.py
Original file line number Diff line number Diff line change
@@ -124,7 +124,7 @@ def __rsub__(self, lhs: int | float | Array) -> Array:
return _apply_op(lhs, self, std_ops.sub)


def asarray(obj: int | float | bool | str | Array) -> Array:
def asarray(obj: int | float | bool | str | Array | np.ndarray) -> Array:
if isinstance(obj, Array):
return obj
data = ascoredata(op.const(obj))
74 changes: 66 additions & 8 deletions ndonnx/_logic_in_data/dtypes.py
Original file line number Diff line number Diff line change
@@ -6,22 +6,46 @@
from abc import ABC, abstractmethod
from functools import reduce
from types import NotImplementedType
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

import numpy as np
from typing_extensions import Self
import spox.opset.ai.onnx.v21 as op

if TYPE_CHECKING:
from typing_extensions import Self

from . import _typed_array
from ._typed_array import core, masked


TY_ARRAY = TypeVar("TY_ARRAY", bound="_typed_array.TyArrayBase[Any]")
TY_ARRAY_CORE = TypeVar("TY_ARRAY_CORE", bound="core.TyArray[Any]")
TY_MA_ARRAY_CORE = TypeVar("TY_MA_ARRAY_CORE", bound="masked.TyMaArray[Any]")


class DType(ABC):
class DType(ABC, Generic[TY_ARRAY]):
@abstractmethod
def _result_type(self, other: DType) -> DType | NotImplementedType: ...

@property
@abstractmethod
def _tyarr_class(self) -> type[_typed_array.TyArrayBase[Self]]: ...
def _tyarr_class(self) -> type[TY_ARRAY]:
"""Consider using `_tyarray_from_tyarray` or `_argument` instead of.
Those functions are better at providing the dtype instance (with its state) to the newly
instantiated array.
"""
...

@abstractmethod
def _tyarray_from_tyarray(self, arr: _typed_array.TyArrayBase) -> TY_ARRAY:
# replaces `TyArrayBase.from_typed_array`
...

# @abstractmethod
# def _argument(self, shape) -> _typed_array.TyArrayBase[Self]:
# # replaces `TyArrayBase.as_argument`
# ...

def __eq__(self, other) -> bool:
if type(self) is not type(other):
@@ -35,11 +59,37 @@ def __repr__(self) -> str:
return self.__class__.__name__


class _CoreDType(DType): ...
class _CoreDType(DType[TY_ARRAY_CORE]):
Foo: type[_typed_array.TyArrayBase[DType]]

@property
@abstractmethod
def _tyarr_class(self) -> type[TY_ARRAY_CORE]: ...

class _NCoreDType(DType):
_unmasked_dtype: _CoreDType
def _tyarray_from_tyarray(self, arr: _typed_array.TyArrayBase) -> TY_ARRAY_CORE:
    """Create an array of this core dtype from `arr`.

    A `TyArray` input has its `var` cast with `op.cast` to the numpy
    dtype given by `as_numpy(self)`, and the result is wrapped in
    `self._tyarr_class`. Any other input raises `NotImplementedError`.
    """
    # Imported inside the function, as in the original implementation.
    from ._typed_array.core import TyArray

    if not isinstance(arr, TyArray):
        raise NotImplementedError

    casted = op.cast(arr.var, to=as_numpy(self))
    return self._tyarr_class(casted)


class _NCoreDType(DType[TY_MA_ARRAY_CORE]):
_unmasked_dtype: CoreDTypes

def _tyarray_from_tyarray(self, arr: _typed_array.TyArrayBase) -> TY_MA_ARRAY_CORE:
    """Create a masked array of this dtype from `arr`.

    * `TyArray` input: its `var` is cast to the numpy dtype of
      `self._unmasked_dtype`; the result carries no mask (`mask=None`).
    * `TyMaArray` input: the existing mask is kept and the data is
      converted with `astype(self._unmasked_dtype)`.

    Any other input raises `NotImplementedError`.
    """
    # Imported inside the function, as in the original implementation.
    from ._typed_array.core import TyArray, ascoredata
    from ._typed_array.masked import TyMaArray

    if isinstance(arr, TyArray):
        core_data = ascoredata(op.cast(arr.var, to=as_numpy(self._unmasked_dtype)))
        return self._tyarr_class(data=core_data, mask=None)

    if not isinstance(arr, TyMaArray):
        raise NotImplementedError

    existing_mask = arr.mask
    converted = arr.data.astype(self._unmasked_dtype)
    return self._tyarr_class(data=converted, mask=existing_mask)


class _Number(_CoreDType):
@@ -177,6 +227,9 @@ def _tyarr_class(self) -> type[_typed_array._ArrayPyInt]:

return _ArrayPyInt

def _tyarray_from_tyarray(
    self, arr: _typed_array.TyArrayBase
) -> _typed_array._ArrayPyInt:
    """Create a Python-int scalar array from `arr` (not implemented).

    The return annotation names the array class that the neighbouring
    `_tyarr_class` property reports, rather than `Self`: on a dtype,
    `Self` would denote the dtype and not the array to be produced.

    Always raises `NotImplementedError`.
    """
    raise NotImplementedError


class _PyFloat(DType):
def _result_type(self, other: DType) -> DType | NotImplementedType:
@@ -194,6 +247,9 @@ def _tyarr_class(self) -> type[_typed_array._ArrayPyFloat]:

return _ArrayPyFloat

def _tyarray_from_tyarray(
    self, arr: _typed_array.TyArrayBase
) -> _typed_array._ArrayPyFloat:
    """Create a Python-float scalar array from `arr` (not implemented).

    The return annotation names the array class that the neighbouring
    `_tyarr_class` property reports, rather than `Self`: on a dtype,
    `Self` would denote the dtype and not the array to be produced.

    Always raises `NotImplementedError`.
    """
    raise NotImplementedError


# Non-nullable Singleton instances
bool_ = Bool()
@@ -581,7 +637,7 @@ def from_numpy(np_dtype: np.dtype) -> CoreDTypes:
raise ValueError(f"'{np_dtype}' does not have a corresponding ndonnx data type")


def as_numpy(dtype: CoreDTypes) -> np.dtype:
def as_numpy(dtype: _CoreDType) -> np.dtype:
if dtype == int8:
return np.dtype("int8")
if dtype == int16:
@@ -616,3 +672,5 @@ def as_numpy(dtype: CoreDTypes) -> np.dtype:

default_int = int64
default_float = float64

DTYPE = TypeVar("DTYPE", bound=DType)
7 changes: 6 additions & 1 deletion tests/test_logic_in_data.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@

from ndonnx._logic_in_data import Array, dtypes
from ndonnx._logic_in_data._typed_array.date_time import DateTime, TimeDelta
from ndonnx._logic_in_data.array import where
from ndonnx._logic_in_data.array import asarray, where


@pytest.mark.parametrize(
@@ -115,3 +115,8 @@ def test_datetime():

res = arr + ten_s_td
assert res.dtype == DateTime("s")


def test_datetime_value_prop():
    """Integer values cast to `DateTime("s")` round-trip to `datetime64[s]`."""
    raw = np.asarray([1, 2])
    expected = np.asarray([1, 2], dtype="datetime64[s]")

    arr = asarray(raw).astype(DateTime("s"))

    np.testing.assert_equal(arr.to_numpy(), expected)

0 comments on commit cdf3d12

Please sign in to comment.