diff --git a/ndonnx/_logic_in_data/__init__.py b/ndonnx/_logic_in_data/__init__.py
index dda4903..8c9cb70 100644
--- a/ndonnx/_logic_in_data/__init__.py
+++ b/ndonnx/_logic_in_data/__init__.py
@@ -19,15 +19,23 @@
 )
 from .funcs import (
     arange,
+    empty,
+    empty_like,
     ones,
+    ones_like,
+    full,
+    full_like,
     finfo,
     iinfo,
-    zeros,
     reshape,
     all,
     isfinite,
     isnan,
     equal,
+    zeros,
+    zeros_like,
+    linspace,
+    where,
 )
 
 __all__ = [
@@ -45,15 +53,23 @@
     "float64",
     "bool",
     "DType",
+    "all",
     "arange",
-    "ones",
+    "asarray",
+    "empty",
+    "empty_like",
+    "equal",
     "finfo",
+    "full",
+    "full_like",
     "iinfo",
-    "zeros",
-    "reshape",
-    "asarray",
-    "all",
     "isfinite",
     "isnan",
-    "equal",
+    "linspace",
+    "ones",
+    "ones_like",
+    "reshape",
+    "where",
+    "zeros",
+    "zeros_like",
 ]
diff --git a/ndonnx/_logic_in_data/_typed_array/core.py b/ndonnx/_logic_in_data/_typed_array/core.py
index b727471..16218ef 100644
--- a/ndonnx/_logic_in_data/_typed_array/core.py
+++ b/ndonnx/_logic_in_data/_typed_array/core.py
@@ -30,6 +30,57 @@
 )
 
 
+class _Index:
+    starts: list[int]
+    ends: list[int]
+    steps: list[int]
+    axes: list[int]
+    squeeze_axes: list[int]
+
+    def __init__(self, index: Index):
+        if isinstance(index, tuple):
+            index_ = index
+        elif isinstance(index, int | slice):
+            index_ = (index,)
+        else:
+            raise NotImplementedError
+        self.starts = []
+        self.ends = []
+        self.steps = []
+        self.axes = []
+        self.squeeze_axes = []
+
+        def compute_end_slice(stop: int | None, step: int | None) -> int:
+            if isinstance(stop, int):
+                return stop
+            step = step or 1
+            # Iterate "to the end"
+            if step < 1:
+                return int(np.iinfo(np.int64).min)
+            return int(np.iinfo(np.int64).max)
+
+        def compute_end_single_idx(start: int):
+            end = start + 1
+            if end == 0:
+                return np.iinfo(np.int64).max
+            return end
+
+        for i, el in enumerate(index_):
+            if isinstance(el, slice):
+                self.starts.append(el.start or 0)
+                self.ends.append(compute_end_slice(el.stop, el.step))
+                self.axes.append(i)
+                self.steps.append(el.step or 1)
+ elif isinstance(el, int): + self.starts.append(el) + self.ends.append(compute_end_single_idx(el)) + self.axes.append(i) + self.steps.append(1) + self.squeeze_axes.append(i) + else: + raise NotImplementedError + + class TyArray(TyArrayBase): dtype: dtypes.CoreDTypes var: Var @@ -40,27 +91,18 @@ def __init__(self, var: Var): self.var = var def __getitem__(self, index: Index) -> Self: - if isinstance(index, int): - var = op.slice( - self.var, - starts=op.const([index]), - ends=op.const([index + 1]), - axes=op.const([0]), - ) - var = op.squeeze(var, axes=op.const([0])) - return type(self)(var) - if isinstance(index, tuple) and all_items_are_int(index): - starts, ends, axes = zip(*[(el, el + 1, i) for i, el in enumerate(index)]) - var = op.slice( - self.var, - starts=op.const(starts), - ends=op.const(ends), - axes=op.const(axes), - ) - var = op.squeeze(var, axes=op.const(axes)) - return type(self)(var) - - raise NotImplementedError + if index == (): + return self + parsed = _Index(index) + var = op.slice( + self.var, + starts=op.const(parsed.starts), + ends=op.const(parsed.ends), + axes=op.const(parsed.axes), + steps=op.const(parsed.steps), + ) + var = op.squeeze(var, axes=op.const(parsed.squeeze_axes, np.int64)) + return type(self)(var) @property def shape(self) -> OnnxShape: @@ -69,6 +111,11 @@ def shape(self) -> OnnxShape: raise ValueError("Missing shape information") return shape + @property + def dynamic_shape(self) -> TyArrayInt64: + var = op.shape(self.var) + return TyArrayInt64(var) + def to_numpy(self) -> np.ndarray: if self.var._value is not None: np_arr = np.asarray(self.var._value.value) @@ -109,8 +156,12 @@ def reshape(self, shape: tuple[int, ...]) -> Self: var = op.reshape(self.var, op.const(shape), allowzero=True) return type(self)(var) - def broadcast_to(self, shape: tuple[int, ...]) -> Self: - var = op.expand(self.var, op.const(shape, dtype=np.int64)) + def broadcast_to(self, shape: tuple[int, ...] 
| TyArrayInt64) -> Self: + if isinstance(shape, tuple): + shape_var = op.const(shape, dtype=np.int64) + else: + shape_var = shape.var + var = op.expand(self.var, shape_var) return type(self)(var) def as_core_dtype(self, dtype: CoreDTypes) -> TyArray: diff --git a/ndonnx/_logic_in_data/_typed_array/date_time.py b/ndonnx/_logic_in_data/_typed_array/date_time.py index 6acf6ef..37ee62b 100644 --- a/ndonnx/_logic_in_data/_typed_array/date_time.py +++ b/ndonnx/_logic_in_data/_typed_array/date_time.py @@ -136,13 +136,17 @@ def __getitem__(self, index: Index) -> Self: def shape(self) -> OnnxShape: return self.data.shape + @property + def dynamic_shape(self) -> TyArrayInt64: + return self.data.dynamic_shape + def reshape(self, shape: tuple[int, ...]) -> Self: is_nat = self.is_nat.reshape(shape) data = self.data.reshape(shape) return type(self)(is_nat=is_nat, data=data, unit=self.dtype.unit) - def broadcast_to(self, shape: tuple[int, ...]) -> Self: + def broadcast_to(self, shape: tuple[int, ...] | TyArrayInt64) -> Self: data = self.data.broadcast_to(shape) is_nat = self.is_nat.broadcast_to(shape) return type(self)(data=data, is_nat=is_nat, unit=self.dtype.unit) diff --git a/ndonnx/_logic_in_data/_typed_array/masked.py b/ndonnx/_logic_in_data/_typed_array/masked.py index 3972cb3..5d857d3 100644 --- a/ndonnx/_logic_in_data/_typed_array/masked.py +++ b/ndonnx/_logic_in_data/_typed_array/masked.py @@ -13,7 +13,7 @@ from .. 
import dtypes from ..dtypes import CoreDTypes, DType, NCoreDTypes from ..schema import Schema, flatten_components -from .core import TyArray, TyArrayBool +from .core import TyArray, TyArrayBool, TyArrayInt64 from .typed_array import TyArrayBase if TYPE_CHECKING: @@ -82,17 +82,18 @@ def __init__(self, data: TyArray, mask: TyArrayBool | None): @property def shape(self) -> OnnxShape: - shape = self.data.shape - if shape is None: - raise ValueError("Missing shape information") - return shape + return self.data.shape + + @property + def dynamic_shape(self) -> TyArrayInt64: + return self.data.dynamic_shape def reshape(self, shape: tuple[int, ...]) -> Self: data = self.data.reshape(shape) mask = self.mask.reshape(shape) if self.mask is not None else None return type(self)(data=data, mask=mask) - def broadcast_to(self, shape: tuple[int, ...]) -> Self: + def broadcast_to(self, shape: tuple[int, ...] | TyArrayInt64) -> Self: data = self.data.broadcast_to(shape) mask = self.mask.broadcast_to(shape) if self.mask else None return type(self)(data=data, mask=mask) diff --git a/ndonnx/_logic_in_data/_typed_array/py_scalars.py b/ndonnx/_logic_in_data/_typed_array/py_scalars.py index 7bd8fcc..3c5f3eb 100644 --- a/ndonnx/_logic_in_data/_typed_array/py_scalars.py +++ b/ndonnx/_logic_in_data/_typed_array/py_scalars.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: from ..array import OnnxShape from ..schema import Components, Schema - from .core import TyArrayBool + from .core import TyArrayBool, TyArrayInt64 class _ArrayPyScalar(TyArrayBase): @@ -54,10 +54,14 @@ def ndim(self) -> int: def shape(self) -> OnnxShape: return () + @property + def dynamic_shape(self) -> TyArrayInt64: + raise ValueError("'dynamic_shape' should never be called on Python scalar") + def reshape(self, shape: tuple[int, ...]) -> Self: raise ValueError("cannot reshape Python scalar") - def broadcast_to(self, shape: tuple[int, ...]) -> Self: + def broadcast_to(self, shape: tuple[int, ...] 
| TyArrayInt64) -> Self: raise ValueError("cannot broadcast Python scalar") def __add__(self, rhs: TyArrayBase) -> TyArrayBase: diff --git a/ndonnx/_logic_in_data/_typed_array/typed_array.py b/ndonnx/_logic_in_data/_typed_array/typed_array.py index 6c7012e..f996c15 100644 --- a/ndonnx/_logic_in_data/_typed_array/typed_array.py +++ b/ndonnx/_logic_in_data/_typed_array/typed_array.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from ..array import Index, OnnxShape from ..schema import Components, Schema - from .core import TyArray, TyArrayBool + from .core import TyArray, TyArrayBool, TyArrayInt64 class TyArrayBase(ABC): @@ -34,6 +34,10 @@ def __getitem__(self, index: Index) -> Self: ... @abstractmethod def shape(self) -> OnnxShape: ... + @property + @abstractmethod + def dynamic_shape(self) -> TyArrayInt64: ... + @abstractmethod def reshape(self, shape: tuple[int, ...]) -> Self: ... @@ -81,7 +85,7 @@ def all(self) -> TyArrayBase: raise ValueError(f"'all' is not implemented for `{self.dtype}`") @abstractmethod - def broadcast_to(self, shape: tuple[int, ...]) -> Self: ... + def broadcast_to(self, shape: tuple[int, ...] | TyArrayInt64) -> Self: ... 
def isnan(self) -> TyArrayBase: raise ValueError(f"'isnan' is not implemented for {self.dtype}") diff --git a/ndonnx/_logic_in_data/array.py b/ndonnx/_logic_in_data/array.py index b391132..148e3f2 100644 --- a/ndonnx/_logic_in_data/array.py +++ b/ndonnx/_logic_in_data/array.py @@ -76,10 +76,19 @@ def _from_data(cls, data: TyArrayBase) -> Array: return inst @property - def shape(self) -> StandardShape: + def shape(self) -> tuple[int | None, ...]: shape = self._data.shape return tuple(None if isinstance(item, str) else item for item in shape) + @property + def ndim(self) -> int: + return len(self.shape) + + @property + def dynamic_shape(self) -> Array: + shape = self._data.dynamic_shape + return Array._from_data(shape) + @property def dtype(self) -> DType: return self._data.dtype @@ -237,13 +246,6 @@ def asarray( return Array._from_data(data) -def where(cond: Array, a: Array, b: Array) -> Array: - from ._typed_array.funcs import typed_where - - data = typed_where(cond._data, a._data, b._data) - return Array._from_data(data) - - def add(a: Array, b: Array) -> Array: return a + b diff --git a/ndonnx/_logic_in_data/funcs.py b/ndonnx/_logic_in_data/funcs.py index 9e48b09..862b31e 100644 --- a/ndonnx/_logic_in_data/funcs.py +++ b/ndonnx/_logic_in_data/funcs.py @@ -17,14 +17,6 @@ def all( return Array._from_data(x._data.all()) -def isnan(array: Array, /) -> Array: - return Array._from_data(array._data.isnan()) - - -def isfinite(array: Array, /) -> Array: - return Array._from_data(array._data.isfinite()) - - def arange( start: int | float, /, @@ -34,23 +26,46 @@ def arange( dtype: DType | None = None, device=None, ) -> Array: - raise NotImplementedError + import builtins - -def equal(x1: Array, x2: Array, /) -> Array: - return x1 == x2 + if dtype is None: + if builtins.all( + isinstance(el, int) for el in [start, stop, step] if el is not None + ): + dtype = dtypes.default_int + else: + dtype = dtypes.default_float + dtype = dtype or dtypes.default_float + if not 
isinstance(dtype, dtypes.CoreDTypes): + raise ValueError(f"Only core data types are supported, found `{dtype}`") + + return asarray(np.arange(start, stop, step), dtype=dtype) + + +def broadcast_to(x: Array, /, shape: tuple[int, ...] | Array) -> Array: + from ._typed_array.core import TyArrayInt64 + + if isinstance(shape, Array): + if not isinstance(shape._data, TyArrayInt64): + raise ValueError( + f"dynamic shape must be of data type int64, found `{shape.dtype}`" + ) + return Array._from_data(x._data.broadcast_to(shape._data)) + return Array._from_data(x._data.broadcast_to(shape)) -def ones( +def empty( shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None ) -> Array: - if dtype is None: - dtype = dtypes.default_float + return zeros(shape=shape, dtype=dtype) - if isinstance(shape, int): - shape = (shape,) - return full(shape, 1, dtype=dtype) +def empty_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array: + return zeros_like(x, dtype=dtype) + + +def equal(x1: Array, x2: Array, /) -> Array: + return x1 == x2 def full( @@ -62,11 +77,65 @@ def full( ) -> Array: if isinstance(shape, int): shape = (shape,) + if dtype is None: + if isinstance(fill_value, bool): + dtype = dtypes.bool_ + elif isinstance(fill_value, int): + dtype = dtypes.default_int + elif isinstance(fill_value, float): + dtype = dtypes.default_float + else: + raise TypeError(f"Unexpected 'fill_value' type `{type(fill_value)}`") return broadcast_to(asarray(fill_value, dtype=dtype), shape) -def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array: - return Array._from_data(x._data.broadcast_to(shape)) +def full_like( + x: Array, + /, + fill_value: bool | int | float, + *, + dtype: DType | None = None, + device=None, +) -> Array: + shape = x.dynamic_shape + fill = asarray(fill_value, dtype=dtype or x.dtype) + return broadcast_to(fill, shape) + + +def isnan(array: Array, /) -> Array: + return Array._from_data(array._data.isnan()) + + +def isfinite(array: Array, /) -> 
Array: + return Array._from_data(array._data.isfinite()) + + +def linspace( + start: int | float | complex, + stop: int | float | complex, + /, + num: int, + *, + dtype: DType | None = None, + device=None, + endpoint: bool = True, +) -> Array: + dtype = dtype or dtypes.default_float + if not isinstance(dtype, dtypes.CoreDTypes): + raise ValueError(f"Only core data types are supported, found `{dtype}`") + return asarray(np.linspace(start, stop, num=num, endpoint=endpoint), dtype=dtype) + + +def ones( + shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None +) -> Array: + dtype = dtype or dtypes.default_float + return full(shape, 1, dtype=dtype) + + +def ones_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array: + dtype = dtype or x.dtype + return full_like(x, 1, dtype=dtype) def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> Array: @@ -77,16 +146,23 @@ def reshape(x: Array, /, shape: tuple[int, ...], *, copy: bool | None = None) -> return Array._from_data(x._data.reshape(shape)) +def where(cond: Array, a: Array, b: Array) -> Array: + from ._typed_array.funcs import typed_where + + data = typed_where(cond._data, a._data, b._data) + return Array._from_data(data) + + def zeros( shape: int | tuple[int, ...], *, dtype: DType | None = None, device=None ) -> Array: - if dtype is None: - dtype = dtypes.default_float + dtype = dtype or dtypes.default_float + return full(shape, 0, dtype=dtype) - if isinstance(shape, int): - shape = (shape,) - return reshape(asarray(0).astype(dtype), shape) +def zeros_like(x: Array, /, *, dtype: DType | None = None, device=None) -> Array: + dtype = dtype or x.dtype + return full_like(x, 0, dtype=dtype) @dataclass diff --git a/tests/test_logic_in_data.py b/tests/test_logic_in_data.py index 8535c85..baf73a7 100644 --- a/tests/test_logic_in_data.py +++ b/tests/test_logic_in_data.py @@ -5,9 +5,9 @@ import numpy as np import pytest -from ndonnx._logic_in_data import Array, dtypes 
+from ndonnx._logic_in_data import Array, dtypes, where from ndonnx._logic_in_data._typed_array.date_time import DateTime, TimeDelta -from ndonnx._logic_in_data.array import asarray, where +from ndonnx._logic_in_data.array import asarray from ndonnx._logic_in_data.build import build from ndonnx._logic_in_data.schema import get_schemas @@ -248,12 +248,22 @@ def test_indexing_shape(): assert arr[0].shape == (None,) +@pytest.mark.parametrize("idx", [0, -1]) @pytest.mark.parametrize("np_array", [np.asarray([[1, 2]]), np.asarray([1, 2])]) -def test_indexing_value_prop_scalar_index(np_array): +def test_indexing_value_prop_scalar_index(np_array, idx): arr = asarray(np_array) - assert arr[0].shape == np_array[0].shape - assert arr[0].dtype == arr.dtype - np.testing.assert_equal(arr[0].unwrap_numpy(), np_array[0]) + assert arr[idx].shape == np_array[idx].shape + assert arr[idx].dtype == arr.dtype + np.testing.assert_equal(arr[idx].unwrap_numpy(), np_array[idx]) + + +@pytest.mark.parametrize("np_array", [np.asarray([[1, 2]]), np.asarray([1, 2])]) +def test_indexing_value_prop_scalar_slice(np_array): + arr = asarray(np_array) + idx = slice(None, 1) + assert arr[idx].shape == np_array[idx].shape + assert arr[idx].dtype == arr.dtype + np.testing.assert_equal(arr[idx].unwrap_numpy(), np_array[idx]) def test_indexing_value_prop_tuple_index():