Skip to content

Commit

Permalink
Implement put
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Jan 28, 2025
1 parent 9b8a11d commit d041523
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 16 deletions.
4 changes: 4 additions & 0 deletions ndonnx/_refactor/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ def from_numpy(np_dtype: np.dtype) -> onnx.DTypes:
# though.
return date_time.DateTime(unit=unit) # type: ignore

if np_dtype.kind == "m":
unit = np.datetime_data(np_dtype)[0]
return date_time.TimeDelta(unit=unit) # type: ignore

# "T" i.e. "Text" is the kind used for `StringDType` in numpy >= 2
# See https://numpy.org/neps/nep-0055-string_dtype.html#python-api-for-stringdtype
if np_dtype.kind in ["U", "T"]:
Expand Down
12 changes: 12 additions & 0 deletions ndonnx/_refactor/_typed_array/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,20 @@ def __setitem__(
value: Self,
/,
) -> None:
if self.dtype != value.dtype:
TypeError(f"data type of 'value' must much array's, found `{value.dtype}`")
self.codes[key] = value.codes

def put(
    self,
    key: TyArrayInt64,
    value: Self,
    /,
) -> None:
    """Write the codes of ``value`` into this array at the positions in ``key``.

    The update is delegated to ``self.codes.put`` and happens in place.

    Raises
    ------
    TypeError
        If the data type of ``value`` differs from the data type of
        this array.
    """
    if self.dtype != value.dtype:
        # The exception must actually be raised; merely constructing it
        # lets codes of an unrelated categorical dtype be written silently.
        raise TypeError(
            f"data type of 'value' must match array's, found `{value.dtype}`"
        )
    self.codes.put(key, value.codes)

@property
def dynamic_shape(self) -> TyArrayInt64:
    """Dynamic shape of this array, as reported by the underlying ``codes``."""
    codes = self.codes
    return codes.dynamic_shape
Expand Down
24 changes: 20 additions & 4 deletions ndonnx/_refactor/_typed_array/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,21 @@ def __setitem__(
value: Self,
/,
) -> None:
raise NotImplementedError
if self.dtype != value.dtype:
TypeError(f"data type of 'value' must much array's, found `{value.dtype}`")
self.data[key] = value.data
self.is_nat[key] = value.is_nat

def put(
    self,
    key: onnx.TyArrayInt64,
    value: Self,
    /,
) -> None:
    """Write ``value`` into this array at the positions in ``key``.

    Both components are updated in place: ``data`` receives
    ``value.data`` and ``is_nat`` receives ``value.is_nat``.

    Raises
    ------
    TypeError
        If the data type of ``value`` differs from the data type of
        this array.
    """
    if self.dtype != value.dtype:
        # The exception must actually be raised; merely constructing it
        # lets values of a different dtype be written silently.
        raise TypeError(
            f"data type of 'value' must match array's, found `{value.dtype}`"
        )
    self.data.put(key, value.data)
    self.is_nat.put(key, value.is_nat)

@property
def dynamic_shape(self) -> onnx.TyArrayInt64:
Expand Down Expand Up @@ -299,7 +313,6 @@ def __init__(self, is_nat: onnx.TyArrayBool, data: onnx.TyArrayInt64, unit: Unit
def __ndx_cast_to__(
self, dtype: DType[TY_ARRAY_BASE]
) -> TY_ARRAY_BASE | NotImplementedType:
res_type = dtype._tyarr_class
if isinstance(dtype, onnx.IntegerDTypes):
data = where(self.is_nat, _NAT_SENTINEL.astype(onnx.int64), self.data)
return data.astype(dtype)
Expand All @@ -316,9 +329,12 @@ def __ndx_cast_to__(
if power > 0:
data = data * np.pow(10, power)
if power < 0:
data = data / np.pow(10, abs(power))
data = data // np.pow(10, abs(power))

data = safe_cast(onnx.TyArrayInt64, data)
# TODO: Figure out why mypy does not like the below
return dtype._tyarr_class(is_nat=self.is_nat, data=data, unit=dtype.unit) # type: ignore

return safe_cast(res_type, data)
return NotImplemented

def unwrap_numpy(self) -> np.ndarray:
Expand Down
33 changes: 22 additions & 11 deletions ndonnx/_refactor/_typed_array/masked_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,6 @@ def _zeros(self, shape: tuple[int, ...] | onnx.TyArrayInt64) -> TY_MA_ARRAY_ONNX


class _NNumber(_MaOnnxDType):
_unmasked_dtype: onnx.NumericDTypes

def __ndx_result_type__(self, rhs: DType | _PyScalar) -> DType | NotImplementedType:
if isinstance(rhs, onnx.NumericDTypes | int | float):
core_result = onnx._result_type(self._unmasked_dtype, rhs)
Expand Down Expand Up @@ -412,19 +410,32 @@ def __getitem__(self, key: GetitemIndex) -> Self:

def __setitem__(self, index: SetitemIndex, value: Self) -> None:
self.data[index] = value.data
new_mask = _merge_masks(
None if self.mask is None else self.mask[index], value.mask
)
if new_mask is None:
if self.mask is None and value.mask is None:
return
if self.mask is None:
shape = self.dynamic_shape
self.mask = safe_cast(
onnx.TyArrayBool, astyarray(False).broadcast_to(shape)
# Create a new mask for self
self.mask = astyarray(False, dtype=onnx.bool_).broadcast_to(
self.dynamic_shape
)
self.mask[index] = new_mask
if value.mask is None:
self.mask[index] = astyarray(False, dtype=onnx.bool_)
else:
self.mask[index] = new_mask
self.mask[index] = value.mask

def put(
    self,
    key: onnx.TyArrayInt64,
    value: Self,
    /,
) -> None:
    """Write ``value`` into this masked array at the positions in ``key``.

    ``data`` is always updated from ``value.data``. The mask is handled
    as follows:

    - neither array has a mask: nothing else to do;
    - only this array has a mask: the entries at ``key`` are reset to
      ``False``, since the incoming values are unmasked (this matches
      what ``__setitem__`` does for an unmasked ``value``);
    - ``value`` has a mask: an all-``False`` mask is created for this
      array if it has none, then ``value.mask`` is written at ``key``.
    """
    self.data.put(key, value.data)

    if value.mask is None:
        if self.mask is not None:
            # Without this, positions that were masked before the update
            # would stay masked even though they now hold valid data.
            self.mask.put(key, astyarray(False, dtype=onnx.bool_))
        return

    if self.mask is None:
        self.mask = astyarray(False, dtype=onnx.bool_).broadcast_to(
            self.dynamic_shape
        )
    self.mask.put(key, value.mask)

def __ndx_cast_to__(self, dtype: DType[TY_ARRAY_BASE]) -> TY_ARRAY_BASE:
# Implemented under the assumption that we know about `onnx`, but not py_scalars
Expand Down
13 changes: 12 additions & 1 deletion ndonnx/_refactor/_typed_array/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,18 @@ def _setitem_int_array(self, key: TyArrayInt64, value: Self) -> None:
else:
value_correct_shape = value.broadcast_to(value_shape)
self.var = op.scatter_nd(self.var, key.var, value_correct_shape.var)
return

def put(
    self,
    key: TyArrayInt64,
    value: Self,
    /,
) -> None:
    """Write ``value`` at the positions ``key`` of the flattened array.

    The array is flattened, updated through ``_setitem_int_array`` and
    reshaped back to ``self.dynamic_shape``; the result replaces
    ``self.var``.
    """
    flat = self.reshape((-1,))
    # A 1D key is given a trailing unit axis before it is handed to
    # `_setitem_int_array`; keys of any other rank are passed unchanged.
    index = key.reshape((-1, 1)) if key.ndim == 1 else key
    flat._setitem_int_array(index, value)
    restored = flat.reshape(self.dynamic_shape)
    self.var = restored.var

@property
def dynamic_shape(self) -> TyArrayInt64:
Expand Down
9 changes: 9 additions & 0 deletions ndonnx/_refactor/_typed_array/typed_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def __setitem__(
/,
) -> None: ...

@abstractmethod
def put(
    self,
    key: TyArrayInt64,
    value: Self,
    /,
) -> None:
    """Set elements with semantics identical to `numpy.put` with `mode="raise"`.

    The array is updated in place; nothing is returned.
    """

# Abstract read-only property returning a `TyArrayInt64`; judging by the
# name this is the array's shape -- NOTE(review): confirm exact semantics.
@property
@abstractmethod
def dynamic_shape(self) -> TyArrayInt64: ...
Expand Down
20 changes: 20 additions & 0 deletions ndonnx/_refactor/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,23 @@ def get_data(x: ndx.Array, /) -> ndx.Array:
if isinstance(x._tyarray, tydx.masked_onnx.TyMaArray):
return ndx.Array._from_tyarray(x._tyarray.data)
return ndx.Array._from_tyarray(x._tyarray)


def put(a: ndx.Array, indices: ndx.Array, updates: ndx.Array, /) -> None:
    """Replace the specified elements of ``a`` with the given values, in place.

    This function follows the semantics of `numpy.put` with
    `mode="raise"`. The data types of ``a`` and ``updates`` must match.
    The indices must be provided as a 1D int64 array; only the int64
    requirement is validated here.

    Raises
    ------
    TypeError
        If ``indices`` is not an int64 array, or if the data types of
        ``a`` and ``updates`` differ.
    """
    from ._typed_array.onnx import TyArrayInt64

    index_tyarray = indices._tyarray
    if not isinstance(index_tyarray, TyArrayInt64):
        raise TypeError(
            f"'indices' must be provided as an int64 tensor, found `{indices.dtype}`"
        )

    dtypes_agree = a.dtype == updates.dtype
    if not dtypes_agree:
        raise TypeError(
            f"data types of 'a' (`{a.dtype}`) and 'updates' (`{updates.dtype}`) must match."
        )

    a._tyarray.put(index_tyarray, updates._tyarray)
14 changes: 14 additions & 0 deletions tests/test_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,3 +464,17 @@ def test_isin_with_type_promotion(np_dtype):
res = ndx.extensions.isin(ndx.asarray(np_arr), test_elements)

np.testing.assert_equal(res.unwrap_numpy(), np_res, strict=True)


@pytest.mark.parametrize(
    "np_dtype",
    [np.int64, np.uint16, np.dtype("datetime64[s]"), np.dtype("timedelta64[s]")],
)
def test_put(np_dtype):
    """Check that `ndx.extensions.put` agrees with `np.put` for a scalar update."""
    np_idx = np.array([0, 2], np.int64)
    np_arr = np.array([1, 2, 3], dtype=np_dtype)
    arr = ndx.asarray(np_arr)
    idx = ndx.asarray(np_idx)

    np.put(np_arr, np_idx, np.asarray(5, dtype=np_dtype))
    ndx.extensions.put(arr, idx, ndx.asarray(5, dtype=dtypes.from_numpy(np_dtype)))

    # Both calls mutate their first argument in place. Without comparing the
    # results the test could only fail if one of the calls raised.
    np.testing.assert_equal(arr.unwrap_numpy(), np_arr, strict=True)

0 comments on commit d041523

Please sign in to comment.