Skip to content

Commit

Permalink
Fix change log and type signature
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 28, 2024
1 parent 1ab34a7 commit a2d76b5
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Changelog
**Bug fixes**

- Various operations that depend on the array's shape have been updated to work correctly with lazy arrays.
- Fixes :func:`~ndonnx.cumulative_sum` to correctly apply the ``include_initial`` parameter and workaround missing ORT kernels for unsigned integral types.
- :func:`~ndonnx.cumulative_sum` now correctly applies the ``include_initial`` parameter and works around missing onnxruntime kernels for unsigned integral types.

**Breaking change**

Expand Down
4 changes: 3 additions & 1 deletion ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,9 @@ def cumulative_sum(
if ndx.iinfo(x.dtype).bits < 64:
out = x.astype(dtypes.int64)
else:
raise ValueError(f"Cannot perform `cumulative_sum` using {x.dtype}")
raise ndx.UnsupportedOperationError(
f"Cannot perform `cumulative_sum` using {x.dtype}"
)
else:
out = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.int64))
else:
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def broadcast_to(x, shape):
# TODO: onnxruntime doesn't work for 2 empty arrays of integer type
# TODO: what is the appropriate strategy to dispatch? (iterate over the inputs and keep trying is reasonable but it can
# change the outcome based on order if poorly implemented)
def concat(arrays, axis=0):
def concat(arrays, /, *, axis: int | None = 0):
if axis is None:
arrays = [reshape(x, [-1]) for x in arrays]
axis = 0
Expand Down
4 changes: 3 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,8 @@ def test_cumulative_sum(array, axis, include_initial, dtype):


def test_no_unsafe_cumulative_sum_cast():
with pytest.raises(ValueError, match="Cannot perform `cumulative_sum`"):
with pytest.raises(
ndx.UnsupportedOperationError, match="Cannot perform `cumulative_sum`"
):
a = ndx.asarray([1, 2, 3], ndx.uint64)
ndx.cumulative_sum(a)

0 comments on commit a2d76b5

Please sign in to comment.