diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 58c1bb7..2aa8730 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,7 @@ Changelog - User defined data types can now define how arrays with that dtype are constructed by implementing the ``make_array`` function. - User defined data types can now define how they are indexed (via ``__getitem__``) by implementing the ``getitem`` function. +- :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx. **Bug fixes** diff --git a/ndonnx/__init__.py b/ndonnx/__init__.py index 358b4b3..9587f99 100644 --- a/ndonnx/__init__.py +++ b/ndonnx/__init__.py @@ -14,6 +14,7 @@ Floating, Integral, Nullable, + NullableCore, NullableFloating, NullableIntegral, NullableNumerical, @@ -323,6 +324,7 @@ "Floating", "NullableIntegral", "Nullable", + "NullableCore", "Integral", "CoreType", "CastError", diff --git a/ndonnx/_core/_boolimpl.py b/ndonnx/_core/_boolimpl.py index 789cc73..f8c93c8 100644 --- a/ndonnx/_core/_boolimpl.py +++ b/ndonnx/_core/_boolimpl.py @@ -101,7 +101,7 @@ def can_cast(self, from_, to) -> bool: @validate_core def all(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) @@ -111,7 +111,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): @validate_core def any(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index 8ea8861..e01edd7 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -741,7 +741,7 @@ def clip( and isinstance(x.dtype, dtypes.Numerical) ): x, min, max = promote(x, min, max) - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): out_null = x.null x_values = x.values._core() clipped = from_corearray(opx.clip(x_values, min._core(), max._core())) @@ -847,7 +847,7 @@ def can_cast(self, from_, to) -> bool: @validate_core def all(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) @@ -857,7 +857,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): @validate_core def any(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) @@ -889,7 +889,7 @@ def arange(self, start, stop=None, step=None, dtype=None, device=None) -> ndx.Ar @validate_core def tril(self, x, k=0) -> ndx.Array: - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): # NumPy appears to just ignore the mask so we do the same x = x.values return x._transmute( @@ -900,7 +900,7 @@ def tril(self, x, k=0) -> ndx.Array: @validate_core def triu(self, x, k=0) -> ndx.Array: - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): # NumPy appears to just ignore the mask so we do the same x = x.values return x._transmute( diff --git a/ndonnx/_core/_stringimpl.py b/ndonnx/_core/_stringimpl.py index 12e12c4..1ba2802 100644 --- a/ndonnx/_core/_stringimpl.py +++ b/ndonnx/_core/_stringimpl.py @@ -54,7 +54,7 @@ def zeros_like( self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None ): if dtype is not None and not isinstance( - dtype, (dtypes.CoreType, dtypes._NullableCore) + dtype, (dtypes.CoreType, dtypes.NullableCore) ): raise TypeError("'dtype' must be a CoreType or NullableCoreType") if dtype in (None, dtypes.utf8, dtypes.nutf8): diff --git a/ndonnx/_core/_utils.py b/ndonnx/_core/_utils.py index ec84943..29e1b74 100644 --- a/ndonnx/_core/_utils.py +++ b/ndonnx/_core/_utils.py @@ -38,7 +38,7 @@ def variadic_op( ): args = promote(*args) out_dtype = args[0].dtype - if not isinstance(out_dtype, (dtypes.CoreType, dtypes._NullableCore)): + if not isinstance(out_dtype, (dtypes.CoreType, dtypes.NullableCore)): raise TypeError( f"Expected ndx.Array with CoreType or NullableCoreType, got {args[0].dtype}" ) @@ -100,7 +100,7 @@ def _via_dtype( promoted = promote(*arrays) out_dtype = promoted[0].dtype - if isinstance(out_dtype, dtypes._NullableCore) and out_dtype.values == dtype: + if isinstance(out_dtype, dtypes.NullableCore) and out_dtype.values == dtype: dtype = out_dtype values, nulls = split_nulls_and_values( @@ -203,7 +203,7 @@ def validate_core(func): def wrapper(*args, **kwargs): for arg in itertools.chain(args, kwargs.values()): if isinstance(arg, ndx.Array) and not isinstance( - arg.dtype, (dtypes.CoreType, dtypes._NullableCore) + arg.dtype, (dtypes.CoreType, dtypes.NullableCore) ): return NotImplemented return func(*args, **kwargs) diff --git a/ndonnx/_data_types/__init__.py b/ndonnx/_data_types/__init__.py index bd9a58a..392abe0 100644 --- a/ndonnx/_data_types/__init__.py +++ b/ndonnx/_data_types/__init__.py @@ -40,7 +40,7 @@ NullableUnsigned, Numerical, Unsigned, - _NullableCore, + NullableCore, from_numpy_dtype, get_finfo, get_iinfo, @@ -51,7 +51,7 @@ from .structtype import StructType -def into_nullable(dtype: StructType | CoreType) -> _NullableCore: +def into_nullable(dtype: StructType | CoreType) -> NullableCore: """Return nullable counterpart, if present. Parameters @@ -61,7 +61,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: Returns ------- - out : _NullableCore + out : NullableCore The nullable counterpart of the input type. Raises @@ -93,7 +93,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: return nuint64 elif dtype == utf8: return nutf8 - elif isinstance(dtype, _NullableCore): + elif isinstance(dtype, NullableCore): return dtype else: raise ValueError(f"Cannot promote {dtype} to nullable") @@ -106,14 +106,14 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: "Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. " "To create nullable array, use 'ndonnx.additional.make_nullable' instead." ) -def promote_nullable(dtype: StructType | CoreType) -> _NullableCore: +def promote_nullable(dtype: StructType | CoreType) -> NullableCore: return into_nullable(dtype) __all__ = [ "CoreType", "StructType", - "_NullableCore", + "NullableCore", "NullableFloating", "NullableIntegral", "NullableUnsigned", diff --git a/ndonnx/_data_types/classes.py b/ndonnx/_data_types/classes.py index 661ef20..c8acd7e 100644 --- a/ndonnx/_data_types/classes.py +++ b/ndonnx/_data_types/classes.py @@ -189,7 +189,7 @@ def _fields(self) -> dict[str, StructType | CoreType]: } -class _NullableCore(Nullable[CoreType], CastMixin): +class NullableCore(Nullable[CoreType], CastMixin): def copy(self) -> Self: return self @@ -213,7 +213,7 @@ def _schema(self) -> Schema: return Schema(type_name=type(self).__name__, author="ndonnx") def _cast_to(self, array: Array, dtype: CoreType | StructType) -> Array: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): return ndx.Array._from_fields( dtype, values=self.values._cast_to(array.values, dtype.values), @@ -230,7 +230,7 @@ def _cast_from(self, array: Array) -> Array: values=self.values._cast_from(array), null=ndx.zeros_like(array, dtype=Boolean()), ) - elif isinstance(array.dtype, _NullableCore): + elif isinstance(array.dtype, NullableCore): return ndx.Array._from_fields( self, values=self.values._cast_from(array.values), @@ -240,7 +240,7 @@ def _cast_from(self, array: Array) -> Array: raise CastError(f"Cannot cast from {array.dtype} to {self}") -class NullableNumerical(_NullableCore): +class NullableNumerical(NullableCore): """Base class for nullable numerical data types.""" _ops: OperationsBlock = NullableNumericOperationsImpl() @@ -312,14 +312,14 @@ class NFloat64(NullableFloating): null = Boolean() -class NBoolean(_NullableCore): +class NBoolean(NullableCore): values = Boolean() null = Boolean() _ops: OperationsBlock = NullableBooleanOperationsImpl() -class NUtf8(_NullableCore): +class NUtf8(NullableCore): values = Utf8() null = Boolean() @@ -405,18 +405,18 @@ def _from_dtype(cls, dtype: CoreType) -> Finfo: ) -def get_finfo(dtype: _NullableCore | CoreType) -> Finfo: +def get_finfo(dtype: NullableCore | CoreType) -> Finfo: try: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): dtype = dtype.values return Finfo._from_dtype(dtype) except KeyError: raise TypeError(f"'{dtype}' is not a floating point data type.") -def get_iinfo(dtype: _NullableCore | CoreType) -> Iinfo: +def get_iinfo(dtype: NullableCore | CoreType) -> Iinfo: try: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): dtype = dtype.values return Iinfo._from_dtype(dtype) except KeyError: diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index 3b4d775..d15dd16 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -11,7 +11,7 @@ import numpy.typing as npt import ndonnx._data_types as dtypes -from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore +from ndonnx._data_types import CastError, CastMixin, CoreType, NullableCore from ndonnx._data_types.structtype import StructType from ndonnx.additional import shape @@ -297,7 +297,7 @@ def result_type( np_dtypes = [] for dtype in observed_dtypes: if isinstance(dtype, dtypes.StructType): - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): nullable = True np_dtypes.append(dtype.values.to_numpy_dtype()) else: @@ -586,7 +586,11 @@ def numeric_like(x): def broadcast_to(x, shape): - return x.dtype._ops.broadcast_to(x, shape) + if (out := x.dtype._ops.broadcast_to(x, shape)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for broadcast_to: '{x.dtype}'" + ) # TODO: onnxruntime doesn't work for 2 empty arrays of integer type @@ -605,27 +609,47 @@ def concat(arrays, /, *, axis: int | None = 0): def expand_dims(x, axis=0): - return x.dtype._ops.expand_dims(x, axis) + if (out := x.dtype._ops.expand_dims(x, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for expand_dims: '{x.dtype}'" + ) def flip(x, axis=None): - return x.dtype._ops.flip(x, axis=axis) + if (out := x.dtype._ops.flip(x, axis=axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for flip: '{x.dtype}'") def permute_dims(x, axes): - return x.dtype._ops.permute_dims(x, axes) + if (out := x.dtype._ops.permute_dims(x, axes)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for permute_dims: '{x.dtype}'" + ) def reshape(x, shape, *, copy=None): - return x.dtype._ops.reshape(x, shape, copy=copy) + if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for reshape: '{x.dtype}'" + ) def roll(x, shift, axis=None): - return x.dtype._ops.roll(x, shift, axis) + if (out := x.dtype._ops.roll(x, shift, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for roll: '{x.dtype}'") def squeeze(x, axis): - return x.dtype._ops.squeeze(x, axis) + if (out := x.dtype._ops.squeeze(x, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for squeeze: '{x.dtype}'" + ) def stack(arrays, axis=0):