From c132f9366fc9849c6c842d7d5eea33c216b2b317 Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Thu, 29 Aug 2024 11:53:52 +0200 Subject: [PATCH] Make `NullableCore` public (#69) --- CHANGELOG.rst | 4 ++++ ndonnx/__init__.py | 2 ++ ndonnx/_core/_boolimpl.py | 4 ++-- ndonnx/_core/_numericimpl.py | 10 +++++----- ndonnx/_core/_stringimpl.py | 2 +- ndonnx/_core/_utils.py | 6 +++--- ndonnx/_data_types/__init__.py | 12 ++++++------ ndonnx/_data_types/classes.py | 20 ++++++++++---------- ndonnx/_funcs.py | 4 ++-- 9 files changed, 35 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9de8604..194dae8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,10 @@ Changelog 0.9.0 (unreleased) ------------------ +**New feature** + +- :class:`ndonnx.NullableCore` is now public, encapsulating nullable variants of `CoreType`s exported by ndonnx. + **Bug fixes** - Various operations that depend on the array's shape have been updated to work correctly with lazy arrays. diff --git a/ndonnx/__init__.py b/ndonnx/__init__.py index 358b4b3..9587f99 100644 --- a/ndonnx/__init__.py +++ b/ndonnx/__init__.py @@ -14,6 +14,7 @@ Floating, Integral, Nullable, + NullableCore, NullableFloating, NullableIntegral, NullableNumerical, @@ -323,6 +324,7 @@ "Floating", "NullableIntegral", "Nullable", + "NullableCore", "Integral", "CoreType", "CastError", diff --git a/ndonnx/_core/_boolimpl.py b/ndonnx/_core/_boolimpl.py index bf78657..e778f62 100644 --- a/ndonnx/_core/_boolimpl.py +++ b/ndonnx/_core/_boolimpl.py @@ -100,7 +100,7 @@ def can_cast(self, from_, to) -> bool: @validate_core def all(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) @@ -110,7 +110,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): @validate_core def any(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index a100e7d..ecfdcd0 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -739,7 +739,7 @@ def clip( and isinstance(x.dtype, dtypes.Numerical) ): x, min, max = promote(x, min, max) - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): out_null = x.null x_values = x.values._core() clipped = from_corearray(opx.clip(x_values, min._core(), max._core())) @@ -856,7 +856,7 @@ def can_cast(self, from_, to) -> bool: @validate_core def all(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, True, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(True, dtype=ndx.bool) @@ -866,7 +866,7 @@ def all(self, x, *, axis=None, keepdims: bool = False): @validate_core def any(self, x, *, axis=None, keepdims: bool = False): - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): x = ndx.where(x.null, False, x.values) if functools.reduce(operator.mul, x._static_shape, 1) == 0: return ndx.asarray(False, dtype=ndx.bool) @@ -898,7 +898,7 @@ def arange(self, start, stop=None, step=None, dtype=None, device=None) -> ndx.Ar @validate_core def tril(self, x, k=0) -> ndx.Array: - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): # NumPy appears to just ignore the mask so we do the same x = x.values return x._transmute( @@ -909,7 +909,7 @@ def tril(self, x, k=0) -> ndx.Array: @validate_core def triu(self, x, k=0) -> ndx.Array: - if isinstance(x.dtype, dtypes._NullableCore): + if isinstance(x.dtype, dtypes.NullableCore): # NumPy appears to just ignore the mask so we do the same x = x.values return x._transmute( diff --git a/ndonnx/_core/_stringimpl.py b/ndonnx/_core/_stringimpl.py index e71cd83..2a42152 100644 --- a/ndonnx/_core/_stringimpl.py +++ b/ndonnx/_core/_stringimpl.py @@ -53,7 +53,7 @@ def zeros_like( self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None ): if dtype is not None and not isinstance( - dtype, (dtypes.CoreType, dtypes._NullableCore) + dtype, (dtypes.CoreType, dtypes.NullableCore) ): raise TypeError("'dtype' must be a CoreType or NullableCoreType") if dtype in (None, dtypes.utf8, dtypes.nutf8): diff --git a/ndonnx/_core/_utils.py b/ndonnx/_core/_utils.py index ec84943..29e1b74 100644 --- a/ndonnx/_core/_utils.py +++ b/ndonnx/_core/_utils.py @@ -38,7 +38,7 @@ def variadic_op( ): args = promote(*args) out_dtype = args[0].dtype - if not isinstance(out_dtype, (dtypes.CoreType, dtypes._NullableCore)): + if not isinstance(out_dtype, (dtypes.CoreType, dtypes.NullableCore)): raise TypeError( f"Expected ndx.Array with CoreType or NullableCoreType, got {args[0].dtype}" ) @@ -100,7 +100,7 @@ def _via_dtype( promoted = promote(*arrays) out_dtype = promoted[0].dtype - if isinstance(out_dtype, dtypes._NullableCore) and out_dtype.values == dtype: + if isinstance(out_dtype, dtypes.NullableCore) and out_dtype.values == dtype: dtype = out_dtype values, nulls = split_nulls_and_values( @@ -203,7 +203,7 @@ def validate_core(func): def wrapper(*args, **kwargs): for arg in itertools.chain(args, kwargs.values()): if isinstance(arg, ndx.Array) and not isinstance( - arg.dtype, (dtypes.CoreType, dtypes._NullableCore) + arg.dtype, (dtypes.CoreType, dtypes.NullableCore) ): return NotImplemented return func(*args, **kwargs) diff --git a/ndonnx/_data_types/__init__.py b/ndonnx/_data_types/__init__.py index 12d580d..83fe3f5 100644 --- a/ndonnx/_data_types/__init__.py +++ b/ndonnx/_data_types/__init__.py @@ -40,7 +40,7 @@ NullableUnsigned, Numerical, Unsigned, - _NullableCore, + NullableCore, from_numpy_dtype, get_finfo, get_iinfo, @@ -51,7 +51,7 @@ from .structtype import StructType -def into_nullable(dtype: StructType | CoreType) -> _NullableCore: +def into_nullable(dtype: StructType | CoreType) -> NullableCore: """Return nullable counterpart, if present. Parameters @@ -61,7 +61,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: Returns ------- - out : _NullableCore + out : NullableCore The nullable counterpart of the input type. Raises @@ -93,7 +93,7 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: return nuint64 elif dtype == utf8: return nutf8 - elif isinstance(dtype, _NullableCore): + elif isinstance(dtype, NullableCore): return dtype else: raise ValueError(f"Cannot promote {dtype} to nullable") @@ -103,14 +103,14 @@ def into_nullable(dtype: StructType | CoreType) -> _NullableCore: "Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. " "To create nullable array, use 'ndonnx.additional.make_nullable' instead." ) -def promote_nullable(dtype: StructType | CoreType) -> _NullableCore: +def promote_nullable(dtype: StructType | CoreType) -> NullableCore: return into_nullable(dtype) __all__ = [ "CoreType", "StructType", - "_NullableCore", + "NullableCore", "NullableFloating", "NullableIntegral", "NullableUnsigned", diff --git a/ndonnx/_data_types/classes.py b/ndonnx/_data_types/classes.py index 661ef20..c8acd7e 100644 --- a/ndonnx/_data_types/classes.py +++ b/ndonnx/_data_types/classes.py @@ -189,7 +189,7 @@ def _fields(self) -> dict[str, StructType | CoreType]: } -class _NullableCore(Nullable[CoreType], CastMixin): +class NullableCore(Nullable[CoreType], CastMixin): def copy(self) -> Self: return self @@ -213,7 +213,7 @@ def _schema(self) -> Schema: return Schema(type_name=type(self).__name__, author="ndonnx") def _cast_to(self, array: Array, dtype: CoreType | StructType) -> Array: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): return ndx.Array._from_fields( dtype, values=self.values._cast_to(array.values, dtype.values), @@ -230,7 +230,7 @@ def _cast_from(self, array: Array) -> Array: values=self.values._cast_from(array), null=ndx.zeros_like(array, dtype=Boolean()), ) - elif isinstance(array.dtype, _NullableCore): + elif isinstance(array.dtype, NullableCore): return ndx.Array._from_fields( self, values=self.values._cast_from(array.values), @@ -240,7 +240,7 @@ def _cast_from(self, array: Array) -> Array: raise CastError(f"Cannot cast from {array.dtype} to {self}") -class NullableNumerical(_NullableCore): +class NullableNumerical(NullableCore): """Base class for nullable numerical data types.""" _ops: OperationsBlock = NullableNumericOperationsImpl() @@ -312,14 +312,14 @@ class NFloat64(NullableFloating): null = Boolean() -class NBoolean(_NullableCore): +class NBoolean(NullableCore): values = Boolean() null = Boolean() _ops: OperationsBlock = NullableBooleanOperationsImpl() -class NUtf8(_NullableCore): +class NUtf8(NullableCore): values = Utf8() null = Boolean() @@ -405,18 +405,18 @@ def _from_dtype(cls, dtype: CoreType) -> Finfo: ) -def get_finfo(dtype: _NullableCore | CoreType) -> Finfo: +def get_finfo(dtype: NullableCore | CoreType) -> Finfo: try: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): dtype = dtype.values return Finfo._from_dtype(dtype) except KeyError: raise TypeError(f"'{dtype}' is not a floating point data type.") -def get_iinfo(dtype: _NullableCore | CoreType) -> Iinfo: +def get_iinfo(dtype: NullableCore | CoreType) -> Iinfo: try: - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): dtype = dtype.values return Iinfo._from_dtype(dtype) except KeyError: diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index 92f5bb1..d8f8b54 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -11,7 +11,7 @@ import numpy.typing as npt import ndonnx._data_types as dtypes -from ndonnx._data_types import CastError, CastMixin, CoreType, _NullableCore +from ndonnx._data_types import CastError, CastMixin, CoreType, NullableCore from ndonnx._data_types.structtype import StructType from ndonnx.additional import shape @@ -291,7 +291,7 @@ def result_type( np_dtypes = [] for dtype in observed_dtypes: if isinstance(dtype, dtypes.StructType): - if isinstance(dtype, _NullableCore): + if isinstance(dtype, NullableCore): nullable = True np_dtypes.append(dtype.values.to_numpy_dtype()) else: