Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into rel-0.9.0
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 29, 2024
2 parents dc3f2aa + edc260a commit 3f8bd4f
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Changelog
- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.
- :func:`ndonnx.cumulative_sum` now correctly applies the ``include_initial`` parameter and works around missing onnxruntime kernels for unsigned integral types.
- :func:`ndonnx.additional.make_nullable` applies broadcasting to the provided null array (instead of reshape like it did previously). This allows writing ``make_nullable(x, False)`` to turn an array into nullable.
- User-defined data types that implement :class:`ndonnx._core.UniformShapeOperations` may now implement :func:`ndonnx.where` without requiring both data types be promotable.

**Breaking change**

Expand Down
9 changes: 2 additions & 7 deletions ndonnx/_core/_boolimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, unary_op, validate_core

if TYPE_CHECKING:
Expand Down Expand Up @@ -165,11 +164,7 @@ def nonzero(self, x) -> tuple[Array, ...]:
return ndx.nonzero(x.astype(ndx.int8))


class BooleanOperationsImpl(
CoreOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...
class BooleanOperationsImpl(CoreOperationsImpl, _BooleanOperationsImpl): ...


class NullableBooleanOperationsImpl(
NullableOperationsImpl, _BooleanOperationsImpl, UniformShapeOperations
): ...
class NullableBooleanOperationsImpl(NullableOperationsImpl, _BooleanOperationsImpl): ...
12 changes: 10 additions & 2 deletions ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
import ndonnx.additional as nda
from ndonnx._corearray import _CoreArray

from ._interface import OperationsBlock
from ._shapeimpl import UniformShapeOperations
from ._utils import validate_core

if TYPE_CHECKING:
from ndonnx._array import Array
from ndonnx._data_types import Dtype


class CoreOperationsImpl(OperationsBlock):
class CoreOperationsImpl(UniformShapeOperations):
def make_array(
self,
shape: tuple[int | None | str, ...],
Expand Down Expand Up @@ -48,3 +48,11 @@ def make_nullable(self, x: Array, null: Array) -> Array:
values=x.copy(),
null=ndx.broadcast_to(null, nda.shape(x)),
)

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)
12 changes: 10 additions & 2 deletions ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import ndonnx as ndx

from ._interface import OperationsBlock
from ._shapeimpl import UniformShapeOperations
from ._utils import validate_core

if TYPE_CHECKING:
Expand All @@ -16,7 +16,7 @@
Dtype = Union[CoreType, StructType]


class NullableOperationsImpl(OperationsBlock):
class NullableOperationsImpl(UniformShapeOperations):
@validate_core
def fill_null(self, x: Array, value) -> Array:
value = ndx.asarray(value)
Expand All @@ -27,3 +27,11 @@ def fill_null(self, x: Array, value) -> Array:
@validate_core
def make_nullable(self, x: Array, null: Array) -> Array:
return NotImplemented

@validate_core
def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return super().where(condition, x, y)
9 changes: 2 additions & 7 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import (
binary_op,
from_corearray,
Expand Down Expand Up @@ -971,14 +970,10 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.full_like(x, 0, dtype=dtype)


class NumericOperationsImpl(
CoreOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...
class NumericOperationsImpl(CoreOperationsImpl, _NumericOperationsImpl): ...


class NullableNumericOperationsImpl(
NullableOperationsImpl, _NumericOperationsImpl, UniformShapeOperations
): ...
class NullableNumericOperationsImpl(NullableOperationsImpl, _NumericOperationsImpl): ...


def _via_i64_f64(
Expand Down
4 changes: 1 addition & 3 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ def full_like(self, x, fill_value, dtype=None, device=None):

def where(self, condition, x, y):
if x.dtype != y.dtype:
target_dtype = ndx.result_type(x, y)
x = ndx.astype(x, target_dtype)
y = ndx.astype(y, target_dtype)
return NotImplemented
if isinstance(condition.dtype, dtypes.Nullable) and not isinstance(
x.dtype, (dtypes.Nullable, dtypes.CoreType)
):
Expand Down
9 changes: 2 additions & 7 deletions ndonnx/_core/_stringimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ._coreimpl import CoreOperationsImpl
from ._interface import OperationsBlock
from ._nullableimpl import NullableOperationsImpl
from ._shapeimpl import UniformShapeOperations
from ._utils import binary_op, validate_core

if TYPE_CHECKING:
Expand Down Expand Up @@ -71,11 +70,7 @@ def empty_like(self, x, dtype=None, device=None) -> ndx.Array:
return ndx.zeros_like(x, dtype=dtype, device=device)


class StringOperationsImpl(
CoreOperationsImpl, _StringOperationsImpl, UniformShapeOperations
): ...
class StringOperationsImpl(CoreOperationsImpl, _StringOperationsImpl): ...


class NullableStringOperationsImpl(
NullableOperationsImpl, _StringOperationsImpl, UniformShapeOperations
): ...
class NullableStringOperationsImpl(NullableOperationsImpl, _StringOperationsImpl): ...
20 changes: 20 additions & 0 deletions tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ def add(self, x, y) -> Array:
return x + y.astype(Unsigned96())
return NotImplemented

def where(self, condition, x, y):
x = x.astype(Unsigned96())
y = y.astype(Unsigned96())
return super().where(condition, x, y)


class Unsigned96(StructType, CastMixin):
def _fields(self) -> dict[str, StructType | CoreType]:
Expand Down Expand Up @@ -373,6 +378,21 @@ def test_custom_dtype_capable_creation_functions():
)


def test_custom_where(u96):
x = ndx.asarray([1, 2, 3], u96)
y = ndx.asarray([4, 5, 6], ndx.uint32)
cond = ndx.asarray([True, False, True])

result1 = ndx.where(cond, x, y)
assert_array_equal(result1, ndx.asarray([1, 5, 3], u96))

result2 = ndx.where(cond, y, x)
assert_array_equal(result2, ndx.asarray([4, 2, 6], u96))

result3 = ndx.where(cond, x, ndx.asarray(0, ndx.uint32))
assert_array_equal(result3, ndx.asarray([1, 0, 3], u96))


def test_create_dtype_mismatched_shape_fields_eager():
array = np.empty(shape=(2,), dtype=object)
array[0] = ["a", "bcd", "e"]
Expand Down

0 comments on commit 3f8bd4f

Please sign in to comment.