diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9de8604..2aa8730 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,12 @@ Changelog 0.9.0 (unreleased) ------------------ +**New features** + +- User defined data types can now define how arrays with that dtype are constructed by implementing the ``make_array`` function. +- User defined data types can now define how they are indexed (via ``__getitem__``) by implementing the ``getitem`` function. +- :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx. + **Bug fixes** - Various operations that depend on the array's shape have been updated to work correctly with lazy arrays. diff --git a/ndonnx/__init__.py b/ndonnx/__init__.py index 358b4b3..9587f99 100644 --- a/ndonnx/__init__.py +++ b/ndonnx/__init__.py @@ -14,6 +14,7 @@ Floating, Integral, Nullable, + NullableCore, NullableFloating, NullableIntegral, NullableNumerical, @@ -323,6 +324,7 @@ "Floating", "NullableIntegral", "Nullable", + "NullableCore", "Integral", "CoreType", "CastError", diff --git a/ndonnx/_array.py b/ndonnx/_array.py index 08195a7..c06bf66 100644 --- a/ndonnx/_array.py +++ b/ndonnx/_array.py @@ -15,6 +15,7 @@ import ndonnx as ndx import ndonnx._data_types as dtypes from ndonnx.additional import shape +from ndonnx.additional._additional import _getitem as getitem from ndonnx.additional._additional import _static_shape as static_shape from ._corearray import _CoreArray @@ -47,7 +48,11 @@ def array( out : Array The new array. This represents an ONNX model input. """ - return Array._construct(shape=shape, dtype=dtype) + if (out := dtype._ops.make_array(shape, dtype)) is not NotImplemented: + return out + raise ndx.UnsupportedOperationError( + f"No implementation of `make_array` for {dtype}" + ) def from_spox_var( @@ -154,17 +159,7 @@ def astype(self, to: CoreType | StructType) -> Array: return ndx.astype(self, to) def __getitem__(self, index: IndexType) -> Array: - if isinstance(index, Array) and not ( - isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool - ): - raise TypeError( - f"Index must be an integral or boolean 'Array', not `{index.dtype}`" - ) - - if isinstance(index, Array): - index = index._core() - - return self._transmute(lambda corearray: corearray[index]) + return getitem(self, index) def __setitem__( self, index: IndexType | Self, updates: int | bool | float | Array diff --git a/ndonnx/_core/_boolimpl.py b/ndonnx/_core/_boolimpl.py index 7c46e91..1797a58 100644 --- a/ndonnx/_core/_boolimpl.py +++ b/ndonnx/_core/_boolimpl.py @@ -12,8 +12,9 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx -import ndonnx.additional as nda +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl from ._shapeimpl import UniformShapeOperations from ._utils import binary_op, unary_op, validate_core @@ -22,7 +23,7 @@ from ndonnx import Array -class BooleanOperationsImpl(UniformShapeOperations): +class _BooleanOperationsImpl(OperationsBlock): @validate_core def equal(self, x, y) -> Array: return binary_op(x, y, opx.equal) @@ -100,7 +101,7 @@ def can_cast(self, from_, to) -> bool: @validate_core def all(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) @@ -110,7 +111,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): @validate_core def any(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) @@ -163,16 +164,6 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array: def nonzero(self, x) -> tuple[Array, ...]: return ndx.nonzero(x.astype(ndx.int8)) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.broadcast_to(null, nda.shape(x)), - ) - @validate_core def where(self, condition, x, y): if x.dtype != y.dtype: @@ -182,6 +173,11 @@ def where(self, condition, x, y): return super().where(condition, x, y) -class NullableBooleanOperationsImpl(BooleanOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented +class BooleanOperationsImpl( + CoreOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations +): ... + + +class NullableBooleanOperationsImpl( + NullableOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations +): ... diff --git a/ndonnx/_core/_coreimpl.py b/ndonnx/_core/_coreimpl.py new file mode 100644 index 0000000..e5b2a58 --- /dev/null +++ b/ndonnx/_core/_coreimpl.py @@ -0,0 +1,50 @@ +# Copyright (c) QuantCo 2023-2024 +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from spox import Tensor, argument + +import ndonnx as ndx +import ndonnx._data_types as dtypes +import ndonnx.additional as nda +from ndonnx._corearray import _CoreArray + +from ._interface import OperationsBlock +from ._utils import validate_core + +if TYPE_CHECKING: + from ndonnx._array import Array + from ndonnx._data_types import Dtype + + +class CoreOperationsImpl(OperationsBlock): + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> Array: + if not isinstance(dtype, dtypes.CoreType): + return NotImplemented + return ndx.Array._from_fields( + dtype, + data=_CoreArray( + dtype._parse_input(eager_value)["data"] + if eager_value is not None + else argument(Tensor(dtype.to_numpy_dtype(), shape)) + ), + ) + + @validate_core + def make_nullable(self, x: Array, null: Array) -> Array: + if null.dtype != ndx.bool: + raise TypeError("'null' must be a boolean array") + + return ndx.Array._from_fields( + dtypes.into_nullable(x.dtype), + values=x.copy(), + null=ndx.broadcast_to(null, nda.shape(x)), + ) diff --git a/ndonnx/_core/_interface.py b/ndonnx/_core/_interface.py index 37fca9a..5340f4f 100644 --- a/ndonnx/_core/_interface.py +++ b/ndonnx/_core/_interface.py @@ -3,11 +3,17 @@ from __future__ import annotations -from typing import Literal +from typing import TYPE_CHECKING, Literal + +import numpy as np import ndonnx as ndx import ndonnx._data_types as dtypes +if TYPE_CHECKING: + from ndonnx._array import IndexType + from ndonnx._data_types import Dtype + class OperationsBlock: """Interface for data types to implement top-level functions exported by ndonnx.""" @@ -251,7 +257,7 @@ def cumulative_sum( x, *, axis: int | None = None, - dtype: ndx.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, include_initial: bool = False, ): return NotImplemented @@ -270,7 +276,7 @@ def prod( x, *, axis=None, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, keepdims: bool = False, ) -> ndx.Array: return NotImplemented @@ -293,7 +299,7 @@ def sum( x, *, axis=None, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, keepdims: bool = False, ) -> ndx.Array: return NotImplemented @@ -305,7 +311,7 @@ def var( axis=None, keepdims: bool = False, correction=0.0, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, ) -> ndx.Array: return NotImplemented @@ -352,7 +358,7 @@ def full_like(self, x, fill_value, dtype=None, device=None) -> ndx.Array: def ones( self, shape, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, device=None, ): return NotImplemented @@ -365,14 +371,12 @@ def ones_like( def zeros( self, shape, - dtype: dtypes.CoreType | dtypes.StructType | None = None, + dtype: Dtype | None = None, device=None, ): return NotImplemented - def zeros_like( - self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None - ): + def zeros_like(self, x, dtype: Dtype | None = None, device=None): return NotImplemented def empty(self, shape, dtype=None, device=None) -> ndx.Array: @@ -413,3 +417,18 @@ def can_cast(self, from_, to) -> bool: def static_shape(self, x) -> tuple[int | None, ...]: return NotImplemented + + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> ndx.Array: + return NotImplemented + + def getitem( + self, + x: ndx.Array, + index: IndexType, + ) -> ndx.Array: + return NotImplemented diff --git a/ndonnx/_core/_nullableimpl.py b/ndonnx/_core/_nullableimpl.py index 71115fc..ce1013b 100644 --- a/ndonnx/_core/_nullableimpl.py +++ b/ndonnx/_core/_nullableimpl.py @@ -1,16 +1,29 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +from typing import TYPE_CHECKING, Union import ndonnx as ndx from ._interface import OperationsBlock from ._utils import validate_core +if TYPE_CHECKING: + from ndonnx._array import Array + from ndonnx._data_types import CoreType, StructType + + Dtype = Union[CoreType, StructType] + class NullableOperationsImpl(OperationsBlock): @validate_core - def fill_null(self, x, value): + def fill_null(self, x: Array, value) -> Array: value = ndx.asarray(value) if value.dtype != x.values.dtype: value = value.astype(x.values.dtype) return ndx.where(x.null, value, x.values) + + @validate_core + def make_nullable(self, x: Array, null: Array) -> Array: + return NotImplemented diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index 428bbf7..75dd098 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -19,6 +19,8 @@ import ndonnx.additional as nda from ndonnx._utility import promote +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl from ._shapeimpl import UniformShapeOperations from ._utils import ( @@ -36,7 +38,7 @@ from ndonnx._corearray import _CoreArray -class NumericOperationsImpl(UniformShapeOperations): +class _NumericOperationsImpl(OperationsBlock): # elementwise.py @validate_core @@ -739,7 +741,7 @@ def clip( and isinstance(x.dtype, dtypes.Numerical) ): x, min, max = promote(x, min, max) - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): out_null = x.null x_values = x.values._core() clipped = from_corearray(opx.clip(x_values, min._core(), max._core())) @@ -837,17 +839,6 @@ def var( - correction ) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.broadcast_to(null, nda.shape(x)), - ) - @validate_core def can_cast(self, from_, to) -> bool: if isinstance(from_, dtypes.CoreType) and isinstance(to, ndx.CoreType): @@ -856,7 +847,7 @@ def can_cast(self, from_, to) -> bool: @validate_core def all(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) @@ -866,7 +857,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): @validate_core def any(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) @@ -898,7 +889,7 @@ def arange(self, start, stop=None, step=None, dtype=None, device=None) -> ndx.Ar @validate_core def tril(self, x, k=0) -> ndx.Array: - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): # NumPy appears to just ignore the mask so we do the same x = x.values return x._transmute( @@ -909,7 +900,7 @@ def tril(self, x, k=0) -> ndx.Array: @validate_core def triu(self, x, k=0) -> ndx.Array: - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): # NumPy appears to just ignore the mask so we do the same x = x.values return x._transmute( @@ -988,9 +979,14 @@ def where(self, condition, x, y): return super().where(condition, x, y) -class NullableNumericOperationsImpl(NumericOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented +class NumericOperationsImpl( + CoreOperationsImpl, _NumericOperationsImpl, UniformShapeOperations +): ... + + +class NullableNumericOperationsImpl( + NullableOperationsImpl, _NumericOperationsImpl, UniformShapeOperations +): ... def _via_i64_f64( diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index d341adb..67da73e 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -1,7 +1,10 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING import numpy as np @@ -13,6 +16,10 @@ from ._interface import OperationsBlock from ._utils import from_corearray +if TYPE_CHECKING: + from ndonnx._array import Array, IndexType + from ndonnx._data_types import Dtype + class UniformShapeOperations(OperationsBlock): """Provides implementation for shape/indexing operations that are generic across all @@ -245,4 +252,55 @@ def zeros_like(self, x, dtype=None, device=None): return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device) def ones_like(self, x, dtype=None, device=None): - return ndx.ones(nda.shape(x), dtype=dtype or x.dtype, device=device) + return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device) + + def make_array( + self, + shape: tuple[int | None | str, ...], + dtype: Dtype, + eager_value: np.ndarray | None = None, + ) -> Array: + if isinstance(dtype, dtypes.CoreType): + return NotImplemented + + fields: dict[str, ndx.Array] = {} + + eager_values = None if eager_value is None else dtype._parse_input(eager_value) + for name, field_dtype in dtype._fields().items(): + if eager_values is None: + field_value = None + else: + field_value = _assemble_output_recurse(field_dtype, eager_values[name]) + fields[name] = field_dtype._ops.make_array( + shape, + field_dtype, + field_value, + ) + return ndx.Array._from_fields( + dtype, + **fields, + ) + + def getitem(self, x: Array, index: IndexType) -> Array: + if isinstance(index, ndx.Array) and not ( + isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool + ): + raise TypeError( + f"Index must be an integral or boolean 'Array', not `{index.dtype}`" + ) + + if isinstance(index, ndx.Array): + index = index._core() + + return x._transmute(lambda corearray: corearray[index]) + + +def _assemble_output_recurse(dtype: Dtype, values: dict) -> np.ndarray: + if isinstance(dtype, dtypes.CoreType): + return dtype._assemble_output(values) + else: + fields = { + name: _assemble_output_recurse(field_dtype, values[name]) + for name, field_dtype in dtype._fields().items() + } + return dtype._assemble_output(fields) diff --git a/ndonnx/_core/_stringimpl.py b/ndonnx/_core/_stringimpl.py index 6f4187d..b8d7cab 100644 --- a/ndonnx/_core/_stringimpl.py +++ b/ndonnx/_core/_stringimpl.py @@ -10,8 +10,9 @@ import ndonnx as ndx import ndonnx._data_types as dtypes import ndonnx._opset_extensions as opx -import ndonnx.additional as nda +from ._coreimpl import CoreOperationsImpl +from ._interface import OperationsBlock from ._nullableimpl import NullableOperationsImpl from ._shapeimpl import UniformShapeOperations from ._utils import binary_op, validate_core @@ -20,7 +21,7 @@ from ndonnx import Array -class StringOperationsImpl(UniformShapeOperations): +class _StringOperationsImpl(OperationsBlock): @validate_core def add(self, x, y) -> Array: return binary_op(x, y, opx.string_concat) @@ -53,7 +54,7 @@ def zeros_like( self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None ): if dtype is not None and not isinstance( - dtype, (dtypes.CoreType, dtypes._NullableCore) + dtype, (dtypes.CoreType, dtypes.NullableCore) ): raise TypeError("'dtype' must be a CoreType or NullableCoreType") if dtype in (None, dtypes.utf8, dtypes.nutf8): @@ -69,17 +70,6 @@ def empty(self, shape, dtype=None, device=None) -> ndx.Array: def empty_like(self, x, dtype=None, device=None) -> ndx.Array: return ndx.zeros_like(x, dtype=dtype, device=device) - @validate_core - def make_nullable(self, x, null): - if null.dtype != dtypes.bool: - raise TypeError("'null' must be a boolean array") - - return ndx.Array._from_fields( - dtypes.into_nullable(x.dtype), - values=x.copy(), - null=ndx.broadcast_to(null, nda.shape(x)), - ) - @validate_core def where(self, condition, x, y): if x.dtype != y.dtype: @@ -89,6 +79,11 @@ def where(self, condition, x, y): return super().where(condition, x, y) -class NullableStringOperationsImpl(StringOperationsImpl, NullableOperationsImpl): - def make_nullable(self, x, null): - return NotImplemented +class StringOperationsImpl( + CoreOperationsImpl, _StringOperationsImpl, UniformShapeOperations +): ... + + +class NullableStringOperationsImpl( + NullableOperationsImpl, _StringOperationsImpl, UniformShapeOperations +): ... diff --git a/ndonnx/_core/_utils.py b/ndonnx/_core/_utils.py index ec84943..29e1b74 100644 --- a/ndonnx/_core/_utils.py +++ b/ndonnx/_core/_utils.py @@ -38,7 +38,7 @@ def variadic_op( ): args = promote(*args) out_dtype = args[0].dtype - if not isinstance(out_dtype, (dtypes.CoreType, dtypes._NullableCore)): + if not isinstance(out_dtype, (dtypes.CoreType, dtypes.NullableCore)): raise TypeError( f"Expected ndx.Array with CoreType or NullableCoreType, got {args[0].dtype}" ) @@ -100,7 +100,7 @@ def _via_dtype( promoted = promote(*arrays) out_dtype = promoted[0].dtype - if isinstance(out_dtype, dtypes._NullableCore) and out_dtype.values == dtype: + if isinstance(out_dtype, dtypes.NullableCore) and out_dtype.values == dtype: dtype = out_dtype values, nulls = split_nulls_and_values( @@ -203,7 +203,7 @@ def validate_core(func): def wrapper(*args, **kwargs): for arg in itertools.chain(args, kwargs.values()): if isinstance(arg, ndx.Array) and not isinstance( - arg.dtype, (dtypes.CoreType, dtypes._NullableCore) + arg.dtype, (dtypes.CoreType, dtypes.NullableCore) ): return NotImplemented return func(*args, **kwargs) diff --git a/ndonnx/_data_types/__init__.py b/ndonnx/_data_types/__init__.py index 12d580d..392abe0 100644 --- a/ndonnx/_data_types/__init__.py +++ b/ndonnx/_data_types/__init__.py @@ -3,7 +3,7 @@ from __future__ import annotations from ndonnx._utility import deprecated - +from typing import Union from .aliases import ( bool, float32, @@ -40,7 +40,7 @@ NullableUnsigned, Numerical, Unsigned, - _NullableCore, + NullableCore, from_numpy_dtype, get_finfo, get_iinfo, @@ -51,7 +51,7 @@ from .structtype import StructType -def into_nullable(dtype: StructType | CoreType) -> _NullableCore: +def into_nullable(dtype: StructType | CoreType) -> NullableCore: """Return nullable counterpart, if present. Parameters @@ -61,7 +61,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: Returns ------- - out : _NullableCore + out : NullableCore The nullable counterpart of the input type. Raises @@ -93,24 +93,27 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: return nuint64 elif dtype == utf8: return nutf8 - elif isinstance(dtype, _NullableCore): + elif isinstance(dtype, NullableCore): return dtype else: raise ValueError(f"Cannot promote {dtype} to nullable") +Dtype = Union[CoreType, StructType] + + @deprecated( "Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. " "To create nullable array, use 'ndonnx.additional.make_nullable' instead." ) -def promote_nullable(dtype: StructType | CoreType) -> _NullableCore: +def promote_nullable(dtype: StructType | CoreType) -> NullableCore: return into_nullable(dtype) __all__ = [ "CoreType", "StructType", - "_NullableCore", + "NullableCore", "NullableFloating", "NullableIntegral", "NullableUnsigned", @@ -151,4 +154,5 @@ def promote_nullable(dtype: StructType | CoreType) -> _NullableCore: "Schema", "CastMixin", "CastError", + "Dtype", ] diff --git a/ndonnx/_data_types/classes.py b/ndonnx/_data_types/classes.py index 661ef20..c8acd7e 100644 --- a/ndonnx/_data_types/classes.py +++ b/ndonnx/_data_types/classes.py @@ -189,7 +189,7 @@ def _fields(self) -> dict[str, StructType | CoreType]: } -class _NullableCore(Nullable[CoreType], CastMixin): +class NullableCore(Nullable[CoreType], CastMixin): def copy(self) -> Self: return self @@ -213,7 +213,7 @@ def _schema(self) -> Schema: return Schema(type_name=type(self).__name__, author="ndonnx") def _cast_to(self, array: Array, dtype: CoreType | StructType) -> Array: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): return ndx.Array._from_fields( dtype, values=self.values._cast_to(array.values, dtype.values), @@ -230,7 +230,7 @@ def _cast_from(self, array: Array) -> Array: values=self.values._cast_from(array), null=ndx.zeros_like(array, dtype=Boolean()), ) - elif isinstance(array.dtype, _NullableCore): + elif isinstance(array.dtype, NullableCore): return ndx.Array._from_fields( self, values=self.values._cast_from(array.values), @@ -240,7 +240,7 @@ def _cast_from(self, array: Array) -> Array: raise CastError(f"Cannot cast from {array.dtype} to {self}") -class NullableNumerical(_NullableCore): +class NullableNumerical(NullableCore): """Base class for nullable numerical data types.""" _ops: OperationsBlock = NullableNumericOperationsImpl() @@ -312,14 +312,14 @@ class NFloat64(NullableFloating): null = Boolean() -class NBoolean(_NullableCore): +class NBoolean(NullableCore): values = Boolean() null = Boolean() _ops: OperationsBlock = NullableBooleanOperationsImpl() -class NUtf8(_NullableCore): +class NUtf8(NullableCore): values = Utf8() null = Boolean() @@ -405,18 +405,18 @@ def _from_dtype(cls, dtype: CoreType) -> Finfo: ) -def get_finfo(dtype: _NullableCore | CoreType) -> Finfo: +def get_finfo(dtype: NullableCore | CoreType) -> Finfo: try: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): dtype = dtype.values return Finfo._from_dtype(dtype) except KeyError: raise TypeError(f"'{dtype}' is not a floating point data type.") -def get_iinfo(dtype: _NullableCore | CoreType) -> Iinfo: +def get_iinfo(dtype: NullableCore | CoreType) -> Iinfo: try: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): dtype = dtype.values return Iinfo._from_dtype(dtype) except KeyError: diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index 92f5bb1..d15dd16 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -11,7 +11,7 @@ import numpy.typing as npt import ndonnx._data_types as dtypes -from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore +from ndonnx._data_types import CastError, CastMixin, CoreType, NullableCore from ndonnx._data_types.structtype import StructType from ndonnx.additional import shape @@ -61,20 +61,26 @@ def asarray( device=None, ) -> Array: if not isinstance(x, Array): - arr = np.asanyarray( + eager_value = np.asanyarray( x, dtype=( dtype.to_numpy_dtype() if isinstance(dtype, dtypes.CoreType) else None ), ) if dtype is None: - dtype = dtypes.from_numpy_dtype(arr.dtype) - if isinstance(arr, np.ma.masked_array): + dtype = dtypes.from_numpy_dtype(eager_value.dtype) + if isinstance(eager_value, np.ma.masked_array): dtype = dtypes.into_nullable(dtype) - ret = Array._construct( - shape=arr.shape, dtype=dtype, eager_values=dtype._parse_input(arr) + ret = dtype._ops.make_array( + shape=eager_value.shape, + dtype=dtype, + eager_value=eager_value, ) + if ret is NotImplemented: + raise UnsupportedOperationError( + f"Unsupported operand type for asarray: '{dtype}'" + ) else: ret = x.copy() if copy is True else x @@ -291,7 +297,7 @@ def result_type( np_dtypes = [] for dtype in observed_dtypes: if isinstance(dtype, dtypes.StructType): - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): nullable = True np_dtypes.append(dtype.values.to_numpy_dtype()) else: @@ -580,7 +586,11 @@ def numeric_like(x): def broadcast_to(x, shape): - return x.dtype._ops.broadcast_to(x, shape) + if (out := x.dtype._ops.broadcast_to(x, shape)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for broadcast_to: '{x.dtype}'" + ) # TODO: onnxruntime doesn't work for 2 empty arrays of integer type @@ -599,27 +609,47 @@ def concat(arrays, /, *, axis: int | None = 0): def expand_dims(x, axis=0): - return x.dtype._ops.expand_dims(x, axis) + if (out := x.dtype._ops.expand_dims(x, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for expand_dims: '{x.dtype}'" + ) def flip(x, axis=None): - return x.dtype._ops.flip(x, axis=axis) + if (out := x.dtype._ops.flip(x, axis=axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for flip: '{x.dtype}'") def permute_dims(x, axes): - return x.dtype._ops.permute_dims(x, axes) + if (out := x.dtype._ops.permute_dims(x, axes)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for permute_dims: '{x.dtype}'" + ) def reshape(x, shape, *, copy=None): - return x.dtype._ops.reshape(x, shape, copy=copy) + if (out := x.dtype._ops.reshape(x, shape, copy=copy)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for reshape: '{x.dtype}'" + ) def roll(x, shift, axis=None): - return x.dtype._ops.roll(x, shift, axis) + if (out := x.dtype._ops.roll(x, shift, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError(f"Unsupported operand type for roll: '{x.dtype}'") def squeeze(x, axis): - return x.dtype._ops.squeeze(x, axis) + if (out := x.dtype._ops.squeeze(x, axis)) is not NotImplemented: + return out + raise UnsupportedOperationError( + f"Unsupported operand type for squeeze: '{x.dtype}'" + ) def stack(arrays, axis=0): diff --git a/ndonnx/additional/_additional.py b/ndonnx/additional/_additional.py index 9c55764..a9f1fc4 100644 --- a/ndonnx/additional/_additional.py +++ b/ndonnx/additional/_additional.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from ndonnx import Array + from ndonnx._array import IndexType Scalar = TypeVar("Scalar", int, float, str) @@ -149,6 +150,15 @@ def make_nullable(x: Array, null: Array) -> Array: return out +def _getitem(x: Array, index: IndexType) -> ndx.Array: + out = x.dtype._ops.getitem(x, index) + if out is NotImplemented: + raise ndx.UnsupportedOperationError( + f"'getitem' not implemented for `{x.dtype}`" + ) + return out + + def _static_shape(x: Array) -> tuple[int | None, ...]: """Return shape of the array as a tuple. Typical implementations will make use of ONNX shape inference, with `None` entries denoting unknown or symbolic dimensions. diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 14252d8..cddc3b0 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -3,6 +3,7 @@ from __future__ import annotations +import functools import re import numpy as np @@ -10,12 +11,19 @@ from typing_extensions import Self import ndonnx as ndx +import ndonnx.additional as nda from ndonnx import ( Array, CastError, CoreType, ) -from ndonnx._experimental import CastMixin, Schema, StructType, UniformShapeOperations +from ndonnx._experimental import ( + CastMixin, + OperationsBlock, + Schema, + StructType, + UniformShapeOperations, +) from .utils import assert_array_equal @@ -137,6 +145,94 @@ def _cast_from(self, array: Array) -> Array: _ops = Unsigned96Impl() +class ListImpl(OperationsBlock): + def make_array( + self, + shape: tuple[int | str | None, ...], + dtype: CoreType | StructType, + eager_value: np.ndarray | None = None, + ) -> Array: + if eager_value is None: + return Array._from_fields( + dtype, + endpoints=ndx.array(shape=shape + (2,), dtype=ndx.int64), + items=ndx.array(shape=(None,), dtype=ndx.utf8), + ) + else: + fields = dtype._parse_input(eager_value) + return Array._from_fields( + dtype, **{name: ndx.asarray(field) for name, field in fields.items()} + ) + + def getitem( + self, + x: Array, + index, + ) -> Array: + if isinstance(index, int): + index = slice(index, index + 1), ... + + return Array._from_fields( + dtype=x.dtype, + endpoints=x.endpoints[index], + items=x.items.copy(), + ) + + def shape(self, x) -> Array: + return nda.shape(x.endpoints)[:-1] + + def static_shape(self, x) -> tuple[int | None, ...]: + return x.endpoints.shape[:-1] + + +class List(StructType): + # The fields here have different shapes + def _fields(self) -> dict[str, StructType | CoreType]: + return { + "endpoints": ndx.int64, + "items": ndx.utf8, + } + + def _parse_input(self, x: np.ndarray) -> dict: + assert x.dtype == object + assert all(isinstance(x, list) for x in x.flat) + + endpoints = np.empty(x.shape + (2,), dtype=np.int64) + items = np.empty( + functools.reduce(lambda acc, elem: acc + len(elem), x.flat, 0), dtype=object + ) + + cur_items_idx = 0 + for idx in np.ndindex(x.shape): + endpoints[idx, :] = [cur_items_idx, cur_items_idx + len(x[idx])] + for elem in x[idx]: + items[cur_items_idx] = elem + cur_items_idx += 1 + + return { + "endpoints": endpoints, + "items": items.astype(np.str_), + } + + def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray: + endpoints = fields["endpoints"] + items = fields["items"] + + out = np.empty(endpoints.shape[:-1], dtype=object) + for idx in np.ndindex(endpoints.shape[:-1]): + start, end = endpoints[idx] + out[idx] = items[start:end].tolist() + return out + + def copy(self) -> Self: + return self + + def _schema(self) -> Schema: + return Schema(type_name="List", author="value from data!") + + _ops = ListImpl() + + def custom_equal(x: Array, y: Array) -> Array: if x.dtype != Unsigned96() or y.dtype != Unsigned96(): raise ValueError("Can only compare Unsigned96 arrays") @@ -292,3 +388,49 @@ def test_custom_where(u96): result2 = ndx.where(cond, y, x) assert_array_equal(result2, ndx.asarray([4, 2, 6], u96)) + + result3 = ndx.where(cond, x, ndx.asarray(0, ndx.uint32)) + assert_array_equal(result3, ndx.asarray([1, 0, 3], u96)) + + +def test_create_dtype_mismatched_shape_fields_eager(): + array = np.empty(shape=(2,), dtype=object) + array[0] = ["a", "bcd", "e"] + array[1] = ["f", "gh"] + x = ndx.asarray(array, dtype=List()) + assert_array_equal(x.to_numpy(), array) + assert x[0].to_numpy().item() == ["a", "bcd", "e"] + assert_array_equal(nda.shape(x).to_numpy(), np.array([2], dtype=np.int64)) + assert x.shape == (2,) + + +def test_create_dtype_mismatched_shape_fields_lazy(): + x = ndx.array(shape=("N", "M", 2), dtype=List()) + assert x.shape == (None, None, 2) + out = x[1:2, 0, ...] + + ndx.build({"x": x}, {"out": out}) + + +def test_recursive_construction(): + class MyNInt64(StructType): + def _fields(self) -> dict[str, StructType | CoreType]: + return {"x": ndx.nint64} + + def _parse_input(self, x: np.ndarray) -> dict: + return {"x": ndx.nint64._parse_input(x)} + + def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray: + return fields["x"] + + def copy(self) -> Self: + return self + + def _schema(self) -> Schema: + return Schema(type_name="my_nint64", author="me") + + _ops = UniformShapeOperations() + + my_nint64 = MyNInt64() + a = ndx.asarray(np.ma.masked_array([1, 2, 3], [1, 0, 1], np.int64), my_nint64) + assert_array_equal(a.to_numpy(), np.ma.masked_array([1, 2, 3], [1, 0, 1], np.int64))