From bd23b014c13588ba0aa25f73606502d32af859ad Mon Sep 17 00:00:00 2001 From: Christian Bourjau Date: Mon, 2 Sep 2024 10:11:22 +0200 Subject: [PATCH] Implement where for nullable data types --- ndonnx/_logic_in_data/_typed_array/core.py | 50 +++++++++++++++---- ndonnx/_logic_in_data/_typed_array/funcs.py | 17 +++++++ ndonnx/_logic_in_data/_typed_array/masked.py | 42 ++++++++++++++-- .../_typed_array/typed_array.py | 17 ++++++- ndonnx/_logic_in_data/array.py | 16 ++---- tests/test_logic_in_data.py | 16 ++++-- 6 files changed, 126 insertions(+), 32 deletions(-) diff --git a/ndonnx/_logic_in_data/_typed_array/core.py b/ndonnx/_logic_in_data/_typed_array/core.py index 0ca26a3..43c8dc1 100644 --- a/ndonnx/_logic_in_data/_typed_array/core.py +++ b/ndonnx/_logic_in_data/_typed_array/core.py @@ -3,7 +3,8 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence +from types import NotImplementedType from typing import TYPE_CHECKING, Any, TypeGuard, TypeVar import numpy as np @@ -95,7 +96,9 @@ def as_core_dtype(self, dtype: CoreDTypes) -> _ArrayCoreType: def _astype(self, dtype: DType) -> _TypedArray: return NotImplemented - def where(self, cond: BoolData, y: _TypedArray) -> _TypedArray: + def _where( + self, cond: BoolData, y: _TypedArray + ) -> _TypedArray | NotImplementedType: if isinstance(y, _ArrayCoreType): x, y = promote(self, y) var = op.where(cond.var, x.var, y.var) @@ -136,18 +139,29 @@ class BoolData(_ArrayCoreType[dtypes.Bool]): dtype = dtypes.bool_ def __or__(self, rhs: _TypedArray) -> _TypedArray: - from .utils import promote + if self.dtype != rhs.dtype: + a, b = promote(self, rhs) + return a | b - if isinstance(rhs, _ArrayCoreType): - if self.dtype != rhs.dtype: - a, b = promote(self, rhs) - return a | b - - # Data is core & bool + if isinstance(rhs, BoolData): var = op.or_(self.var, rhs.var) return ascoredata(var) return NotImplemented + def __and__(self, rhs: _TypedArray) -> _TypedArray: + if self.dtype != rhs.dtype: + a, b = promote(self, rhs) + return a & b + + if isinstance(rhs, BoolData): + var = op.and_(self.var, rhs.var) + return ascoredata(var) + return NotImplemented + + def __invert__(self) -> BoolData: + var = op.not_(self.var) + return type(self)(var) + class Int8Data(_ArrayCoreInteger[dtypes.Int8]): dtype = dtypes.int8 @@ -203,3 +217,21 @@ def is_sequence_of_core_data( seq: Sequence[_TypedArray], ) -> TypeGuard[Sequence[_ArrayCoreType]]: return all(isinstance(d, _ArrayCoreType) for d in seq) + + +def _promote_and_apply_op( + lhs: _ArrayCoreType, + rhs: _TypedArray, + arr_op: Callable[[_ArrayCoreType, _ArrayCoreType], _ArrayCoreType], + spox_op: Callable[[Var, Var], Var], +) -> _ArrayCoreType: + """Promote and apply an operation by passing it through to the data member.""" + if isinstance(rhs, _ArrayCoreType): + if lhs.dtype != rhs.dtype: + a, b = promote(lhs, rhs) + return arr_op(a, b) + + # Data is core & integer + var = spox_op(lhs.var, rhs.var) + return ascoredata(var) + return NotImplemented diff --git a/ndonnx/_logic_in_data/_typed_array/funcs.py b/ndonnx/_logic_in_data/_typed_array/funcs.py index aedccec..8f07fe7 100644 --- a/ndonnx/_logic_in_data/_typed_array/funcs.py +++ b/ndonnx/_logic_in_data/_typed_array/funcs.py @@ -9,6 +9,23 @@ from .typed_array import _TypedArray +def typed_where(cond: _TypedArray, x: _TypedArray, y: _TypedArray) -> _TypedArray: + from .core import BoolData + + # TODO: Masked condition + if not isinstance(cond, BoolData): + raise TypeError("'cond' must be a boolean data type.") + + ret = x._where(cond, y) + if ret == NotImplemented: + ret = y._rwhere(cond, x) + if ret == NotImplemented: + raise TypeError( + f"Unsuppoerted operand data types for 'where': `{x.dtype}` and `{y.dtype}`" + ) + return ret + + def astypedarray( val: int | float | np.ndarray | _TypedArray | Var, dtype: None | DType = None, diff --git a/ndonnx/_logic_in_data/_typed_array/masked.py b/ndonnx/_logic_in_data/_typed_array/masked.py index b636e74..8f7c8a3 100644 --- a/ndonnx/_logic_in_data/_typed_array/masked.py +++ b/ndonnx/_logic_in_data/_typed_array/masked.py @@ -6,6 +6,7 @@ import operator from collections.abc import Callable from dataclasses import dataclass +from types import NotImplementedType from typing import TYPE_CHECKING, Any, TypeVar import spox.opset.ai.onnx.v21 as op @@ -111,9 +112,40 @@ def _astype(self, dtype: DType) -> _TypedArray: dtype._tyarr_class(data=new_data, mask=self.mask) return NotImplemented - def where(self, cond: BoolData, y: _TypedArray) -> _TypedArray: - # TODO - raise NotImplementedError + def _where( + self, cond: BoolData, y: _TypedArray + ) -> _TypedArray | NotImplementedType: + if isinstance(y, _ArrayCoreType): + return self._where(cond, asncoredata(y, None)) + if isinstance(y, _ArrayMaCoreType): + x_ = unmask_core(self) + y_ = unmask_core(y) + new_data = x_._where(cond, y_) + if self.mask is not None and y.mask is not None: + new_mask = cond & self.mask | ~cond & y.mask + elif self.mask is not None: + new_mask = cond & self.mask + elif y.mask is not None: + new_mask = ~cond & y.mask + else: + new_mask = None + + if new_mask is not None and not isinstance(new_mask, BoolData): + # Should never happen. Might be worth while adding + # overloads to the BoolData dunder methods to + # propagate the types more precisely. + raise TypeError(f"expected boolean mask, found `{new_mask.dtype}`") + + return asncoredata(new_data, new_mask) + + return NotImplemented + + def _rwhere( + self, cond: BoolData, x: _TypedArray + ) -> _TypedArray | NotImplementedType: + if isinstance(x, _ArrayCoreType): + return asncoredata(x, None)._where(cond, self) + return NotImplemented class NBoolData(_ArrayMaCoreType[dtypes.NBool]): @@ -229,10 +261,10 @@ def _apply_op( ) -> _ArrayMaCoreType: """Apply an operation by passing it through to the data member.""" if isinstance(rhs, _ArrayCoreType): - data = lhs.data + rhs + data = op(lhs.data, rhs) mask = lhs.mask elif isinstance(rhs, _ArrayMaCoreType): - data = lhs.data + rhs.data + data = op(lhs.data, rhs.data) mask = _merge_masks(lhs.mask, rhs.mask) else: return NotImplemented diff --git a/ndonnx/_logic_in_data/_typed_array/typed_array.py b/ndonnx/_logic_in_data/_typed_array/typed_array.py index 634b87c..f934fb1 100644 --- a/ndonnx/_logic_in_data/_typed_array/typed_array.py +++ b/ndonnx/_logic_in_data/_typed_array/typed_array.py @@ -96,11 +96,24 @@ def _astype(self, dtype: DType) -> _TypedArray | NotImplementedType: """ return NotImplemented - @abstractmethod - def where(self, cond: BoolData, y: _TypedArray) -> _TypedArray: ... + def _where( + self, cond: BoolData, y: _TypedArray + ) -> _TypedArray | NotImplementedType: + return NotImplemented + + def _rwhere( + self, cond: BoolData, y: _TypedArray + ) -> _TypedArray | NotImplementedType: + return NotImplemented def __add__(self, other: _TypedArray) -> _TypedArray: return NotImplemented + def __and__(self, rhs: _TypedArray) -> _TypedArray: + return NotImplemented + + def __invert__(self) -> _TypedArray: + return NotImplemented + def __or__(self, rhs: _TypedArray) -> _TypedArray: return NotImplemented diff --git a/ndonnx/_logic_in_data/array.py b/ndonnx/_logic_in_data/array.py index 89dddbd..0c0a20c 100644 --- a/ndonnx/_logic_in_data/array.py +++ b/ndonnx/_logic_in_data/array.py @@ -62,6 +62,8 @@ def __init__(self, shape=None, dtype=None, value=None, var=None): @classmethod def _from_data(cls, data: _TypedArray) -> Array: + if not isinstance(data, _TypedArray): + raise TypeError(f"expected '_TypedArray', found `{type(data)}`") inst = cls.__new__(cls) inst._data = data return inst @@ -117,19 +119,9 @@ def asarray(obj: int | float | bool | str | Array) -> Array: def where(cond: Array, a: Array, b: Array) -> Array: - from ._typed_array import BoolData - from .dtypes import bool_, nbool + from ._typed_array.funcs import typed_where - if cond.dtype not in [bool_, nbool]: - raise ValueError - - # TODO: NBoolData - if not isinstance(cond._data, BoolData): - raise ValueError( - f"condition must be of a boolean data type; found `{cond.dtype}`" - ) - - data = a._data.where(cond._data, b._data) + data = typed_where(cond._data, a._data, b._data) return Array._from_data(data) diff --git a/tests/test_logic_in_data.py b/tests/test_logic_in_data.py index 338f734..1fdc1d4 100644 --- a/tests/test_logic_in_data.py +++ b/tests/test_logic_in_data.py @@ -77,14 +77,22 @@ def test__getitem__(): assert arr[0].shape == (None,) -def test_where(): +@pytest.mark.parametrize( + "x_ty, y_ty, res_ty", + [ + (dtypes.int16, dtypes.int32, dtypes.int32), + (dtypes.nint16, dtypes.int32, dtypes.nint32), + (dtypes.int32, dtypes.nint16, dtypes.nint32), + ], +) +def test_where(x_ty, y_ty, res_ty): shape = ("N", "M") cond = Array(shape, dtypes.bool_) - x = Array(shape, dtypes.int16) - y = Array(shape, dtypes.int32) + x = Array(shape, x_ty) + y = Array(shape, y_ty) res = where(cond, x, y) - assert res.dtype == dtypes.int32 + assert res.dtype == res_ty assert res._data.shape == shape assert res.shape == (None, None)