Skip to content

Commit

Permalink
Passing datetime test
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Sep 5, 2024
1 parent 2406c9b commit 1d0ef5f
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 33 deletions.
10 changes: 10 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ def __add__(self, rhs: TyArrayBase) -> TyArrayBase:
return ascoredata(var)
return NotImplemented

def __mul__(self, rhs: TyArrayBase) -> TyArrayBase:
if isinstance(rhs, TyArrayNumber):
# NOTE: Can't always promote for all data types (c.f. datetime / timedelta)
if type(self) != type(rhs):
a, b = promote(self, rhs)
return a * b
var = op.mul(self.var, rhs.var)
return ascoredata(var)
return NotImplemented

def __sub__(self, rhs: TyArrayBase) -> TyArrayBase:
if isinstance(rhs, TyArrayNumber):
# NOTE: Can't always promote for all data types (c.f. datetime / timedelta)
Expand Down
27 changes: 16 additions & 11 deletions ndonnx/_logic_in_data/_typed_array/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .funcs import astypedarray, typed_where
from .py_scalars import _ArrayPyInt
from .typed_array import TyArrayBase
from .utils import safe_cast

if TYPE_CHECKING:
from types import NotImplementedType
Expand All @@ -24,14 +25,6 @@

Unit = Literal["ns", "s"]

T = TypeVar("T")


def safe_cast(ty: type[T], a: TyArrayBase) -> T:
if isinstance(a, ty):
return a
raise TypeError(f"Expected 'TyArrayInt64' found `{type(a)}`")


class DateTime(DType):
def __init__(self, unit: Unit):
Expand Down Expand Up @@ -144,6 +137,12 @@ def _astype(self, dtype: DType) -> TyArrayBase | NotImplementedType:
return data
return NotImplemented

def __mul__(self, rhs: TyArrayBase) -> TyArrayTimeDelta | TyArrayTimeDelta:
if isinstance(rhs, _ArrayPyInt):
data = cast(TyArrayInt64, (self.data * rhs))
return type(self)(is_nat=self.is_nat, data=data, unit=self.dtype.unit)
raise NotImplementedError


class TyArrayDateTime(TimeBaseArray[DateTime]):
def __init__(self, is_nat: TyArrayBool, data: TyArrayInt64, unit: Unit):
Expand Down Expand Up @@ -196,10 +195,16 @@ def _astype(self, dtype: DType) -> TyArrayBase | NotImplementedType:
return NotImplemented

def __add__(self, rhs: TyArrayBase) -> TyArrayDateTime | TyArrayTimeDelta:
rhs_data: _ArrayPyInt | TyArrayInt64
if isinstance(rhs, _ArrayPyInt):
data = cast(TyArrayInt64, (self.data + rhs))
return type(self)(is_nat=self.is_nat, data=data, unit=self.dtype.unit)
raise NotImplementedError
rhs_data = rhs
elif isinstance(rhs, TyArrayTimeDelta):
rhs_data = rhs.data
else:
raise NotImplementedError

data = cast(TyArrayInt64, (self.data + rhs_data))
return type(self)(is_nat=self.is_nat, data=data, unit=self.dtype.unit)

def __sub__(self, rhs: TyArrayBase) -> TyArrayDateTime | TyArrayTimeDelta:
if isinstance(rhs, _ArrayPyInt):
Expand Down
40 changes: 26 additions & 14 deletions ndonnx/_logic_in_data/_typed_array/py_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

import operator
from collections.abc import Callable
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -13,7 +15,7 @@
from .core import TyArray, TyArrayNumber
from .masked import TyMaArrayNumber
from .typed_array import DTYPE, TyArrayBase
from .utils import promote
from .utils import promote, safe_cast

if TYPE_CHECKING:
from ..array import OnnxShape
Expand Down Expand Up @@ -59,22 +61,16 @@ def reshape(self, shape: tuple[int, ...]) -> Self:
raise ValueError("cannot reshape Python scalar")

def __add__(self, rhs: TyArrayBase[DType]) -> TyArrayBase[DType]:
if isinstance(rhs, TyArrayNumber | TyMaArrayNumber):
lhs, rhs = promote(self, rhs)
return lhs + rhs

# We only know about the other (nullable) built-in types &
# these scalars should never interact with themselves.
return NotImplemented
return _promote_and_apply_op(self, rhs, operator.add)

def __radd__(self, lhs: TyArrayBase[DType]) -> TyArrayBase[DType]:
if isinstance(lhs, TyArrayNumber | TyMaArrayNumber):
lhs, rhs = promote(lhs, self)
return lhs + rhs
return _promote_and_apply_op(lhs, self, operator.add)

# We only know about the other (nullable) built-in types &
# these scalars should never interact with themselves.
return NotImplemented
def __mul__(self, rhs: TyArrayBase[DType]) -> TyArrayBase[DType]:
return _promote_and_apply_op(self, rhs, operator.mul)

def __rmul__(self, lhs: TyArrayBase[DType]) -> TyArrayBase[DType]:
return _promote_and_apply_op(lhs, self, operator.mul)

def __or__(self, rhs: TyArrayBase) -> TyArray:
return NotImplemented
Expand Down Expand Up @@ -103,3 +99,19 @@ class _ArrayPyInt(_ArrayPyScalar[dtypes._PyInt]):

class _ArrayPyFloat(_ArrayPyScalar[dtypes._PyFloat]):
dtype = dtypes._pyfloat


def _promote_and_apply_op(
lhs: TyArrayBase | _ArrayPyScalar,
rhs: TyArrayBase | _ArrayPyScalar,
arr_op: Callable[[TyArray, TyArray], TyArray],
) -> TyArrayNumber | TyMaArrayNumber:
# Must be xor for the two unions
if isinstance(rhs, _ArrayPyScalar) != isinstance(lhs, _ArrayPyScalar):
lhs, rhs = promote(lhs, rhs)
# I am not sure how to annotate `safe_cast` such that it can handle union types.
return safe_cast(TyArrayNumber | TyMaArrayNumber, lhs + rhs) # type: ignore

# We only know about the other (nullable) built-in types &
# these scalars should never interact with themselves.
return NotImplemented
3 changes: 3 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/typed_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def __and__(self, rhs: TyArrayBase) -> TyArrayBase:
def __invert__(self) -> TyArrayBase:
return NotImplemented

def __mul__(self, rhs: TyArrayBase) -> TyArrayBase:
return NotImplemented

def __or__(self, rhs: TyArrayBase) -> TyArrayBase:
return NotImplemented

Expand Down
10 changes: 9 additions & 1 deletion ndonnx/_logic_in_data/_typed_array/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from __future__ import annotations

from types import NotImplementedType
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, TypeVar, overload

if TYPE_CHECKING:
from ..dtypes import DType
from .core import TyArray
from .typed_array import TyArrayBase

T = TypeVar("T")


@overload
def promote(lhs: TyArray, *others: TyArray) -> tuple[TyArray, ...]: ...
Expand All @@ -32,3 +34,9 @@ def promote(lhs: TyArrayBase, *others: TyArrayBase) -> tuple[TyArrayBase[DType],
raise TypeError("Failed to promote into common data type.")
acc = updated
return tuple([lhs.astype(acc)] + [other.astype(acc) for other in others])


def safe_cast(ty: type[T], a: TyArrayBase) -> T:
if isinstance(a, ty):
return a
raise TypeError(f"Expected 'TyArrayInt64' found `{type(a)}`")
22 changes: 15 additions & 7 deletions ndonnx/_logic_in_data/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,28 +93,36 @@ def __getitem__(self, index: Index) -> Array:
data = self._data[index]
return type(self)._from_data(data)

# __r*__ are needed for interacting with Python scalars
# (e.g. doing 1 + Array(...)). These functions are _NOT_ used to
# dispatch between different `_TypedArray` subclasses.
##################################################################
# __r*__ are needed for interacting with Python scalars #
# (e.g. doing 1 + Array(...)). These functions are _NOT_ used to #
# dispatch between different `_TypedArray` subclasses. #
##################################################################

def __add__(self, rhs: int | float | Array) -> Array:
return _apply_op(self, rhs, std_ops.add)

def __radd__(self, lhs: int | float | Array) -> Array:
return _apply_op(lhs, self, std_ops.add)

def __sub__(self, rhs: int | float | Array) -> Array:
return _apply_op(self, rhs, std_ops.sub)
def __mul__(self, rhs: int | float | Array) -> Array:
return _apply_op(self, rhs, std_ops.mul)

def __rsub__(self, lhs: int | float | Array) -> Array:
return _apply_op(lhs, self, std_ops.sub)
def __rmul__(self, lhs: int | float | Array) -> Array:
return _apply_op(lhs, self, std_ops.mul)

def __or__(self, rhs: int | float | Array) -> Array:
return _apply_op(self, rhs, std_ops.or_)

def __ror__(self, lhs: int | float | Array) -> Array:
return _apply_op(lhs, self, std_ops.or_)

def __sub__(self, rhs: int | float | Array) -> Array:
return _apply_op(self, rhs, std_ops.sub)

def __rsub__(self, lhs: int | float | Array) -> Array:
return _apply_op(lhs, self, std_ops.sub)


def asarray(obj: int | float | bool | str | Array) -> Array:
if isinstance(obj, Array):
Expand Down

0 comments on commit 1d0ef5f

Please sign in to comment.