diff --git a/ndonnx/_logic_in_data/_typed_array/date_time.py b/ndonnx/_logic_in_data/_typed_array/date_time.py index 659d24c..66bf529 100644 --- a/ndonnx/_logic_in_data/_typed_array/date_time.py +++ b/ndonnx/_logic_in_data/_typed_array/date_time.py @@ -13,7 +13,7 @@ from ..dtypes import TY_ARRAY, DType from ..schema import DTypeInfo, flatten_components from . import onnx, py_scalars -from .funcs import astypedarray, typed_where +from .funcs import astyarray, where from .typed_array import TyArrayBase from .utils import safe_cast @@ -183,7 +183,7 @@ def __init__(self, is_nat: onnx.TyArrayBool, data: onnx.TyArrayInt64, unit: Unit def __ndx_astype__(self, dtype: DType[TY_ARRAY]) -> TY_ARRAY | NotImplementedType: res_type = dtype._tyarr_class if isinstance(dtype, onnx.CoreIntegerDTypes): - data = typed_where(self.is_nat, _NAT_SENTINEL, self.data) + data = where(self.is_nat, _NAT_SENTINEL, self.data) return data.astype(dtype) if isinstance(dtype, TimeDelta): powers = { @@ -193,9 +193,7 @@ def __ndx_astype__(self, dtype: DType[TY_ARRAY]) -> TY_ARRAY | NotImplementedTyp "ns": 9, } power = powers[dtype.unit] - powers[self.dtype.unit] - data = typed_where( - self.is_nat, astypedarray(np.iinfo(np.int64).min), self.data - ) + data = where(self.is_nat, astyarray(np.iinfo(np.int64).min), self.data) if power > 0: data = data * np.pow(10, power) @@ -243,9 +241,7 @@ def unwrap_numpy(self) -> np.ndarray: def __ndx_astype__(self, dtype: DType[TY_ARRAY]) -> TY_ARRAY | NotImplementedType: res_type = dtype._tyarr_class if isinstance(dtype, onnx.CoreIntegerDTypes): - data = typed_where( - self.is_nat, astypedarray(np.iinfo(np.int64).min), self.data - ) + data = where(self.is_nat, astyarray(np.iinfo(np.int64).min), self.data) return data.astype(dtype) if isinstance(dtype, DateTime): powers = { @@ -255,9 +251,7 @@ def __ndx_astype__(self, dtype: DType[TY_ARRAY]) -> TY_ARRAY | NotImplementedTyp "ns": 9, } power = powers[dtype.unit] - powers[self.dtype.unit] - data = typed_where( - self.is_nat, astypedarray(np.iinfo(np.int64).min), self.data - ) + data = where(self.is_nat, astyarray(np.iinfo(np.int64).min), self.data) if power > 0: data = data * np.pow(10, power) diff --git a/ndonnx/_logic_in_data/_typed_array/funcs.py b/ndonnx/_logic_in_data/_typed_array/funcs.py index 50e3b87..dee4e7d 100644 --- a/ndonnx/_logic_in_data/_typed_array/funcs.py +++ b/ndonnx/_logic_in_data/_typed_array/funcs.py @@ -14,24 +14,7 @@ from .utils import safe_cast -def typed_where(cond: TyArrayBase, x: TyArrayBase, y: TyArrayBase) -> TyArrayBase: - from . import TyArrayBool - - # TODO: Masked condition - if not isinstance(cond, TyArrayBool): - raise TypeError("'cond' must be a boolean data type.") - - ret = x.__ndx_where__(cond, y) - if ret is NotImplemented: - ret = y.__ndx_rwhere__(cond, x) - if ret is NotImplemented: - raise TypeError( - f"Unsupported operand data types for 'where': `{x.dtype}` and `{y.dtype}`" - ) - return ret - - -def astypedarray( +def astyarray( val: int | float | str | np.ndarray | TyArrayBase | Var, dtype: None | DType = None, use_py_scalars=False, @@ -69,6 +52,28 @@ def astypedarray( return arr +######################################################################### +# Free functions implemented via `__ndx_*__` methods on the typed array # +######################################################################### + + +def where(cond: TyArrayBase, x: TyArrayBase, y: TyArrayBase) -> TyArrayBase: + from . import TyArrayBool + + # TODO: Masked condition + if not isinstance(cond, TyArrayBool): + raise TypeError("'cond' must be a boolean data type.") + + ret = x.__ndx_where__(cond, y) + if ret is NotImplemented: + ret = y.__ndx_rwhere__(cond, x) + if ret is NotImplemented: + raise TypeError( + f"Unsupported operand data types for 'where': `{x.dtype}` and `{y.dtype}`" + ) + return ret + + def maximum(x1: TyArrayBase, x2: TyArrayBase, /) -> TyArrayBase: res = x1.__ndx_maximum__(x2) if res is NotImplemented: @@ -78,14 +83,3 @@ def maximum(x1: TyArrayBase, x2: TyArrayBase, /) -> TyArrayBase: f"Unsupported operand data types for 'max': `{x1.dtype}` and `{x2.dtype}`" ) return res - - -def sum( - x: TyArrayBase, - /, - *, - axis: int | tuple[int, ...] | None = None, - dtype: DType | None = None, - keepdims: bool = False, -) -> TyArrayBase: - return x.sum(axis=axis, dtype=dtype, keepdims=keepdims) diff --git a/ndonnx/_logic_in_data/_typed_array/masked_onnx.py b/ndonnx/_logic_in_data/_typed_array/masked_onnx.py index 31d6043..3024b77 100644 --- a/ndonnx/_logic_in_data/_typed_array/masked_onnx.py +++ b/ndonnx/_logic_in_data/_typed_array/masked_onnx.py @@ -14,7 +14,7 @@ from ..dtypes import TY_ARRAY, DType from ..schema import DTypeInfo, Schema, flatten_components from . import onnx -from .funcs import astypedarray +from .funcs import astyarray from .typed_array import TyArrayBase from .utils import safe_cast @@ -321,7 +321,7 @@ def __setitem__(self, index: SetitemIndex, value: Self) -> None: if self.mask is None: shape = self.dynamic_shape self.mask = safe_cast( - onnx.TyArrayBool, astypedarray(False).broadcast_to(shape) + onnx.TyArrayBool, astyarray(False).broadcast_to(shape) ) self.mask[index] = new_mask else: diff --git a/ndonnx/_logic_in_data/_typed_array/typed_array.py b/ndonnx/_logic_in_data/_typed_array/typed_array.py index 101f171..167eac8 100644 --- a/ndonnx/_logic_in_data/_typed_array/typed_array.py +++ b/ndonnx/_logic_in_data/_typed_array/typed_array.py @@ -31,6 +31,10 @@ def __repr__(self) -> str: @abstractmethod def __ndx_value_repr__(self) -> dict[str, str]: """A string representation of the fields to be used in ``Array.__repr__```.""" + # Note: It is unfortunate that this part of the API relies on + # the rather useless `dict[str, str]` type hint. `TypedDict` + # is not a viable solution (?) since it does not play nicely + # with the subtyping. @abstractmethod def __getitem__(self, index: GetitemIndex) -> Self: ... diff --git a/ndonnx/_logic_in_data/array.py b/ndonnx/_logic_in_data/array.py index 6c9e3a1..ad7d1ca 100644 --- a/ndonnx/_logic_in_data/array.py +++ b/ndonnx/_logic_in_data/array.py @@ -13,7 +13,7 @@ from spox import Var from ._typed_array import TyArrayBase -from ._typed_array.funcs import astypedarray +from ._typed_array.funcs import astyarray from .dtypes import DType StrictShape = tuple[int, ...] @@ -60,7 +60,7 @@ def __init__(self, shape=None, dtype=None, value=None, var=None): if isinstance(value, np.ndarray): raise NotImplementedError if isinstance(value, int | float): - ty_arr = astypedarray(value, use_py_scalars=False, dtype=dtype) + ty_arr = astyarray(value, use_py_scalars=False, dtype=dtype) self._data = ty_arr return @@ -278,7 +278,7 @@ def asarray( return obj if isinstance(obj, bool | int | float): obj = np.array(obj) - data = astypedarray(obj) + data = astyarray(obj) if dtype: data = data.astype(dtype) return Array._from_data(data) @@ -289,7 +289,7 @@ def _as_array( ) -> Array: if isinstance(val, Array): return val - ty_arr = astypedarray(val, use_py_scalars=use_py_scalars) + ty_arr = astyarray(val, use_py_scalars=use_py_scalars) return Array._from_data(ty_arr) diff --git a/ndonnx/_logic_in_data/funcs.py b/ndonnx/_logic_in_data/funcs.py index 9f9517e..ce71a12 100644 --- a/ndonnx/_logic_in_data/funcs.py +++ b/ndonnx/_logic_in_data/funcs.py @@ -168,15 +168,13 @@ def sum( dtype: DType | None = None, keepdims: bool = False, ) -> Array: - from ._typed_array.funcs import sum - - return Array._from_data(sum(x._data)) + return Array._from_data(x._data.sum(axis=axis, dtype=dtype, keepdims=keepdims)) def where(cond: Array, a: Array, b: Array) -> Array: - from ._typed_array.funcs import typed_where + from ._typed_array import funcs as tyfuncs - data = typed_where(cond._data, a._data, b._data) + data = tyfuncs.where(cond._data, a._data, b._data) return Array._from_data(data)